diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index d89f0cdd91e52..0000000000000 --- a/.github/stale.yml +++ /dev/null @@ -1,22 +0,0 @@ -# Number of days of inactivity before an issue becomes stale -daysUntilStale: 30 - -# Number of days of inactivity before a stale issue is closed -daysUntilClose: 7 - -# Issues with these labels will never be considered stale -exemptLabels: - - contributions welcome - - feature request - - regression - -# Label to use when marking an issue as stale -staleLabel: stale - -# Comment to post when marking an issue as stale. Set to `false` to disable -markComment: > - This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details. - -# Comment to post when closing a stale issue. Set to `false` to disable -closeComment: > - This issue has been automatically closed due to inactivity. Please reactivate if further support is needed. diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000000000..67d8550d44204 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,34 @@ +name: Close stale issues +on: + # Allows you to dictate when you want this workflow to run using cron syntax (times in UTC) + schedule: + - cron: "0 15 * * *" + # Allows you to run this workflow manually from the Actions tab + # workflow_dispatch: + +jobs: + close-stale-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v4.1.1 + with: + # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale + exempt-issue-labels: contributions welcome, feature request, regression + # Number of days without activity before the actions/stale action labels an issue + days-before-issue-stale: 30 + # Number of days without activity before the actions/stale action closes an issue + days-before-issue-close: 7 + # Label you want to apply to issues that have been inactive for the amount of time specified by days-before-issue-stale + stale-issue-label: "stale" + # Comment that you want to add to issues that are labeled by the actions/stale action + stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details." + # Comment that you want to add to issues that are closed by the actions/stale action + close-issue-message: "This issue has been automatically closed due to inactivity. Please reactivate if further support is needed." + # If you never want this action to label PRs, set this value to -1 + days-before-pr-stale: -1 + # If you never want this action to close PRs, set this value to -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index f9501253661a2..6b0e3659bd234 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -26,7 +26,7 @@ "component": { "type": "git", "git": { - "commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40", + "commitHash": "b86cc54efce19530fb953e4b21f57e6b3888534c", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "git submodule at cmake/external/onnx" @@ -192,16 +192,6 @@ "comments": "mp11" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "6a20ba82b439ea1fd650da4d389e96b60a1dd828", - "repositoryUrl": "https://github.com/onnx/onnx.git" - }, - "comments": "onnx" - } - }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index f81a268d38dff..94181448fd21c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1282,14 +1282,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_CONFIG_CPU_FP16=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP16) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_U8) - add_definitions(-DOPENVINO_CONFIG_VPUX_U8=1) - endif() - if (onnxruntime_USE_OPENVINO_GPU_FP32_NP) add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) @@ -1310,16 +1302,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP32_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP32=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_FP16_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - if (onnxruntime_USE_OPENVINO_HETERO) add_definitions(-DOPENVINO_CONFIG_HETERO=1) add_definitions(-DDEVICE_NAME="${onnxruntime_USE_OPENVINO_DEVICE}") diff --git a/cmake/deps.txt b/cmake/deps.txt index 631d326e2ba5b..aeb7c05080abb 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -24,7 +24,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/6a20ba82b439ea1fd650da4d389e96b60a1dd828.zip;179a22ad4cd67109c60031ae4b6cf2f434d8bd7e +onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa diff --git a/cmake/external/onnx b/cmake/external/onnx index 6a20ba82b439e..b86cc54efce19 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit 6a20ba82b439ea1fd650da4d389e96b60a1dd828 +Subproject commit b86cc54efce19530fb953e4b21f57e6b3888534c diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 003012f8da071..043789c36c327 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -38,6 +38,8 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio @@ -246,4 +248,4 @@ install(TARGETS onnxruntime_providers_cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index bf9adbaefabcc..a9a78668b4810 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -387,6 +387,9 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*" ) + file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*" + ) file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py" ) @@ -741,6 +744,7 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/graph_optimizers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton/kernel COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils @@ -794,6 +798,9 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_ortmodule_graph_optimizers_srcs} + $/onnxruntime/training/ortmodule/graph_optimizers/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ort_triton_srcs} $/onnxruntime/training/ort_triton/ diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index de1458c120016..6ccf063c71290 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -48,6 +48,9 @@ set(contrib_ops_excluded_files "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" + "math/gemm_float8.cc" + "math/gemm_float8.cu" + "math/gemm_float8.h" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" @@ -103,6 +106,9 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/sharding.cc") list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") endif() set(provider_excluded_files diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1a76c18a6a8e0..ed1049b0bd73a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -40,6 +40,7 @@ Do not modify directly.* * com.microsoft.GatherND * com.microsoft.Gelu * com.microsoft.GemmFastGelu + * com.microsoft.GemmFloat8 * com.microsoft.GreedySearch * com.microsoft.GridSample * com.microsoft.GroupNorm @@ -94,6 +95,7 @@ Do not modify directly.* * com.microsoft.RotaryEmbedding * com.microsoft.SampleOp * com.microsoft.Sampling + * com.microsoft.SkipGroupNorm * com.microsoft.SkipLayerNormalization * com.microsoft.SkipSimplifiedLayerNormalization * com.microsoft.Snpe @@ -2137,6 +2139,71 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GemmFloat8** + + Generic Gemm for float and float 8. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : string
+
Activation function, RELU or GELU or NONE (default).
+
alpha : float
+
Scalar multiplier for the product of input tensors A * B.
+
beta : float
+
Scalar multiplier for the product of input bias C.
+
dtype : int
+
Output Type. Same definition as attribute 'to' for operator Cast.
+
transA : int
+
Whether A should be transposed. Float 8 only supprted transA=0.
+
transB : int
+
Whether B should be transposed. Float 8 only supprted transB=1.
+
+ +#### Inputs (2 - 6) + +
+
A : TA
+
Input tensor A. The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+
B : TB
+
Input tensor B. The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
+
C (optional) : TC
+
Input tensor C.
+
scaleA (optional) : TS
+
Scale of tensor A if A is float 8 tensor
+
scaleB (optional) : TS
+
Scale of tensor B if B is float 8 tensor
+
scaleY (optional) : TS
+
Scale of the output tensor if A or B is float 8.
+
+ +#### Outputs + +
+
Y : TR
+
Output tensor of shape (M, N).
+
+ +#### Type Constraints + +
+
TA : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input A.
+
TB : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input B.
+
TC : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input C.
+
TR : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to result type.
+
TS : tensor(float)
+
Constrain type for all input scales (scaleA, scaleB, scaleY).
+
+ + ### **com.microsoft.GreedySearch** Greedy Search for text generation. @@ -2276,7 +2343,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation : int (required)
-
Activation after group normalization: 0 for None, 1 for Swish
+
Activation after group normalization: 0 for None, 1 for SiLU
channels_last : int
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
epsilon : float
@@ -2516,6 +2583,7 @@ This version of the operator has been available since version 1 of the 'com.micr Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + #### Version This version of the operator has been available since version 1 of the 'com.microsoft' operator set. @@ -5017,6 +5085,72 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.SkipGroupNorm** + + This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + + This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. + The num_channels must be divisible by num_groups. + The mean and standard-deviation of s are calculated separately over the each group. + The weight and bias are per-channel affine transform parameter vectors of size num_channels. + + The activation attribute can be used to enable activation after group normalization. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : int (required)
+
Activation after group normalization: 0 for None, 1 for SiLU
+
channels_last : int
+
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
+
epsilon : float
+
The epsilon value to use to avoid division by zero
+
groups : int (required)
+
The number of groups of channels. It should be a divisor of the number of channels C
+
+ +#### Inputs (4 - 5) + +
+
X : T
+
Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
gamma : M
+
1D gamma tensor for normalization with shape (C), where C is number of channels
+
beta : M
+
1D beta tensor for normalization with shape (C), where C is number of channels
+
skip : T
+
4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)
+
bias (optional) : T
+
1D bias tensor. Dimensions are (C), where C is number of channels
+
+ +#### Outputs (1 - 2) + +
+
Y : T
+
The output tensor of the same shape as X
+
S (optional) : T
+
The element-wise sum of input x, skip and bias tensors. It has the same shape as X
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X, skip, bias and output Y, S types to float tensors.
+
M : tensor(float16), tensor(float)
+
Constrain gamma and beta to float tensors.
+
+ + ### **com.microsoft.SkipLayerNormalization** Skip and Layer Normalization Fusion diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 84249df92231b..dcdf73cbdbf08 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -522,10 +522,8 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|11|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|11|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |||10|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -696,39 +694,26 @@ Do not modify directly.* |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* output:**T**|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |Reciprocal|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||13|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| |ReduceSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -846,6 +831,7 @@ Do not modify directly.* |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -875,6 +861,7 @@ Do not modify directly.* |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| +|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| diff --git a/docs/python/ReadMeOV.rst b/docs/python/ReadMeOV.rst index f12c01d278dca..6ef16e1378139 100644 --- a/docs/python/ReadMeOV.rst +++ b/docs/python/ReadMeOV.rst @@ -7,7 +7,6 @@ OpenVINOâ„¢ Execution Provider for ONNX Runtime accelerates inference across man - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs Installation ------------ @@ -22,7 +21,6 @@ This package supports: - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs ``pip3 install onnxruntime-openvino`` diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 8f2b5af870506..680ce1cc5b9a2 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -25,13 +25,14 @@ struct OrtTensorRTProviderOptionsV2 { int trt_dla_core{0}; // DLA core number. Default 0 int trt_dump_subgraphs{0}; // dump TRT subgraph. Default 0 = false, nonzero = true int trt_engine_cache_enable{0}; // enable engine caching. Default 0 = false, nonzero = true - const char* trt_engine_cache_path{nullptr}; // specify engine cache path + const char* trt_engine_cache_path{nullptr}; // specify engine cache path, defaults to the working directory int trt_engine_decryption_enable{0}; // enable engine decryption. Default 0 = false, nonzero = true const char* trt_engine_decryption_lib_path{nullptr}; // specify engine decryption library path int trt_force_sequential_engine_build{0}; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true int trt_context_memory_sharing_enable{0}; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true int trt_layer_norm_fp32_fallback{0}; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true int trt_timing_cache_enable{0}; // enable TensorRT timing cache. Default 0 = false, nonzero = true + const char* trt_timing_cache_path{nullptr}; // specify timing cache path, if none is provided the trt_engine_cache_path is used int trt_force_timing_cache{0}; // force the TensorRT cache to be used even if device profile does not match. Default 0 = false, nonzero = true int trt_detailed_build_log{0}; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true int trt_build_heuristics_enable{0}; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4a63018f870a6..729a302f3dd0f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -611,7 +611,7 @@ typedef struct OrtMIGraphXProviderOptions { typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus OrtOpenVINOProviderOptions() : device_type{}, - enable_vpu_fast_compile{}, + enable_npu_fast_compile{}, device_id{}, num_of_threads{}, cache_dir{}, @@ -624,7 +624,7 @@ typedef struct OrtOpenVINOProviderOptions { * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" */ const char* device_type; - unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled + unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled const char* device_id; size_t num_of_threads; ///< 0 = Use default number of threads const char* cache_dir; // path is set to empty by default @@ -4605,6 +4605,10 @@ struct OrtCustomOp { OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); + + // Get start range + int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); + int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 467eb31ee2c8e..92c25d8688b66 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2228,6 +2228,8 @@ struct ShapeInferContext { using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); +#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 + template struct CustomOpBase : OrtCustomOp { CustomOpBase() { @@ -2280,6 +2282,14 @@ struct CustomOpBase : OrtCustomOp { } SetShapeInferFn(0); + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->end_ver_; + }; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider @@ -2348,6 +2358,9 @@ struct CustomOpBase : OrtCustomOp { protected: // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index b12221e56b79f..443710884743a 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -773,8 +773,11 @@ struct OrtLiteCustomOp : public OrtCustomOp { PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ) OrtLiteCustomOp(const char* op_name, - const char* execution_provider) : op_name_(op_name), - execution_provider_(execution_provider) { + const char* execution_provider, + int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -837,6 +840,16 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtCustomOp::KernelCompute = {}; OrtCustomOp::InferOutputShapeFn = {}; + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->end_ver_; + }; } const std::string op_name_; @@ -844,6 +857,9 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -873,9 +889,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtLiteCustomFunc(const char* op_name, const char* execution_provider, ComputeFn compute_fn, - ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_(compute_fn), - shape_infer_fn_(shape_infer_fn) { + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_(compute_fn), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -911,9 +929,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtLiteCustomFunc(const char* op_name, const char* execution_provider, ComputeFnReturnStatus compute_fn_return_status, - ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_return_status_(compute_fn_return_status), - shape_infer_fn_(shape_infer_fn) { + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_return_status_(compute_fn_return_status), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -985,8 +1005,9 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { }; OrtLiteCustomStruct(const char* op_name, - const char* execution_provider) : OrtLiteCustomOp(op_name, - execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { @@ -1049,25 +1070,31 @@ template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, void (*custom_compute_fn)(Args...), - Status (*shape_infer_fn)(ShapeInferContext&) = {}) { + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, Status (*custom_compute_fn_v2)(Args...), - Status (*shape_infer_fn)(ShapeInferContext&) = {}) { + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, - const char* execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomStruct; - return std::make_unique(op_name, execution_provider).release(); + return std::make_unique(op_name, execution_provider, start_ver, end_ver).release(); } } // namespace Custom diff --git a/js/web/.npmignore b/js/web/.npmignore index 16f487f5ff4cd..0f018f525a8d6 100644 --- a/js/web/.npmignore +++ b/js/web/.npmignore @@ -4,6 +4,26 @@ /dist/**/*.report.html +# We remove some of the files in NPM packages because restrictions in jsdelivr: +# +# "Packages larger than 150 MB or single files larger than 20 MB (in the case of GitHub) are not supported" +# +# from https://www.jsdelivr.com/documentation +# +# We only include development build in the NPM package for the following targets: +# - /dist/ort.js +# - /dist/ort.all.js +# +/dist/cjs/ort.js +/dist/esm/ort.js +/dist/cjs/ort.all.js +/dist/esm/ort.all.js +/dist/**/ort.wasm.js +/dist/**/ort.wasm-core.js +/dist/**/ort.webgl.js +/dist/**/ort.webgpu.js +/dist/**/ort.training.wasm.js + /types/ karma.conf.js diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index c5c27a4318049..6060271ced156 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -7,6 +7,9 @@ // So we import code inside the if-clause to allow bundler remove the code safely. export * from 'onnxruntime-common'; +import * as ort from 'onnxruntime-common'; +export default ort; + import {registerBackend, env} from 'onnxruntime-common'; import {version} from './version'; diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5d66caf77f08f..eb40da048835e 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -126,14 +126,14 @@ export class WebGpuBackend { */ kernels: Map unknown) | undefined, unknown]]>; - commandEncoder: GPUCommandEncoder|null = null; - computePassEncoder: GPUComputePassEncoder|null = null; + private commandEncoder: GPUCommandEncoder|null = null; + private computePassEncoder: GPUComputePassEncoder|null = null; pendingDispatchNumber = 0; - supportTimestampQuery = false; - profilingQuerySet: GPUQuerySet; - profilingQueryData: GpuData; - profilingTimeBase?: bigint; + queryData?: GpuData; + querySet?: GPUQuerySet; + querySetCount = 2; + queryTimeBase?: bigint; env: Env; @@ -168,11 +168,9 @@ export class WebGpuBackend { }, requiredFeatures, }; - // WebGPU Spec: Timestamp Queries Inside Passes - // https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md - if (adapter.features.has('timestamp-query-inside-passes')) { - this.supportTimestampQuery = true; - requiredFeatures.push('timestamp-query-inside-passes' as GPUFeatureName); + + if (adapter.features.has('timestamp-query')) { + requiredFeatures.push('timestamp-query'); } if (adapter.features.has('shader-f16')) { requiredFeatures.push('shader-f16'); @@ -197,21 +195,14 @@ export class WebGpuBackend { } }; - if (this.supportTimestampQuery) { - this.profilingQuerySet = this.device.createQuerySet({ - type: 'timestamp', - count: 2, - }); - } - Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); } dispose(): void { - // currently, we do not do anything in this function. In all known use cases, we don't have the requirement to - // actually dispose the WebGpuBackend instance, because it's always used as a singleton. - // - // revisit this place if we get real requirement to dispose the instance. + if (typeof this.querySet !== 'undefined') { + this.querySet.destroy(); + } + this.gpuDataManager.dispose(); } getCommandEncoder(): GPUCommandEncoder { @@ -223,7 +214,22 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { - this.computePassEncoder = this.getCommandEncoder().beginComputePass(); + const computePassDescriptor: GPUComputePassDescriptor = {}; + if (this.isQueryEnabled()) { + if (typeof this.querySet === 'undefined') { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.querySetCount, + }); + } + computePassDescriptor.timestampWrites = { + querySet: this.querySet, + beginningOfPassWriteIndex: 0, + endOfPassWriteIndex: 1, + }; + } + + this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -245,6 +251,14 @@ export class WebGpuBackend { } } + isQueryEnabled(): boolean { + if (this.device.features.has('timestamp-query') && this.env.webgpu.profilingMode === 'default') { + return true; + } else { + return false; + } + } + /** * run a WebGPU program. * @param program a ProgramInfo instance diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index d4dbad79e613e..ec651ce34e8c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -10,7 +10,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {getMaxComponents, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { @@ -37,23 +37,39 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut const cols = shape[axis]; const rows = outputSize / cols; + const components = getMaxComponents(cols); + const packedCols = cols / components; + const valueType = components === 1 ? dataType : `vec${components}<${dataType}>`; + + const maxVector = (name: string, components: number) => { + if (components === 4) { + return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`; + } else if (components === 2) { + return `max(${name}.x, ${name}.y)`; + } else if (components === 3) { + return `max(max(${name}.x, ${name}.y), ${name}.z)`; + } + + return name; + }; // 6.2.4 in wgsl spec - const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;'; + const threadMaxDecl = + dataType === 'f32' ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (_shaderHelper: ShaderHelper) => ` - var rowMaxShared : ${dataType}; - var rowSumShared : ${dataType}; - var threadShared : array<${dataType}, ${WG}>; + var rowMaxShared : ${valueType}; + var rowSumShared : ${valueType}; + var threadShared : array<${valueType}, ${WG}>; - @group(0) @binding(0) var x : array<${dataType}>; - @group(0) @binding(1) var result : array<${dataType}>; + @group(0) @binding(0) var x : array<${valueType}>; + @group(0) @binding(1) var result : array<${valueType}>; - fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} { + fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} { let index = row * row_stride + col; return x[index]; } - fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) { + fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) { let index = row * row_stride + col; result[index] = value; } @@ -64,8 +80,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut let lindex = i32(local_id.x); const wg = ${WG}; let row = gindex / wg; - let cols = ${cols}; - let row_stride : i32 = ${cols}; + let cols = ${packedCols}; + let row_stride : i32 = ${packedCols}; // find the rows max ${threadMaxDecl} @@ -87,12 +103,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowMaxShared = threadShared[0]; + rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)}); } workgroupBarrier(); // find the rows sum - var threadSum: ${dataType} = 0.0; + var threadSum = ${valueType}(0.0); for (var col = lindex; col < cols; col += wg) { let subExp = exp(getValue(row, col, row_stride) - rowMaxShared); threadSum += subExp; @@ -107,7 +123,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowSumShared = threadShared[0]; + rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)}); } workgroupBarrier(); diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 5c5a07d90d34a..341e6edf26cc8 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -38,14 +38,6 @@ export class ProgramManager { const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); - const profilingEnabled = this.backend.supportTimestampQuery && this.backend.env.webgpu.profilingMode === 'default'; - if (profilingEnabled) { - // profiling write start timestamp - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 0); - } - computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { @@ -65,24 +57,20 @@ export class ProgramManager { this.backend.pendingDispatchNumber++; - if (profilingEnabled) { - // profiling write end timestamp - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 1); - if (this.backend.profilingQueryData == null) { - this.backend.profilingQueryData = + if (this.backend.isQueryEnabled()) { + if (typeof this.backend.queryData === 'undefined') { + this.backend.queryData = this.backend.gpuDataManager.create( // eslint-disable-next-line no-bitwise - this.backend.gpuDataManager.create(16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); + this.backend.querySetCount * 8, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); } - // eslint-disable-next-line no-bitwise - const syncData = this.backend.gpuDataManager.create(16, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); + const syncData = this.backend.gpuDataManager.create( + // eslint-disable-next-line no-bitwise + this.backend.querySetCount * 8, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); this.backend.endComputePass(); - this.backend.getCommandEncoder().resolveQuerySet( - this.backend.profilingQuerySet, 0, 2, this.backend.profilingQueryData.buffer, 0); + this.backend.getCommandEncoder().resolveQuerySet(this.backend.querySet, 0, 2, this.backend.queryData.buffer, 0); this.backend.getCommandEncoder().copyBufferToBuffer( - this.backend.profilingQueryData.buffer, 0, syncData.buffer, 0, 16); + this.backend.queryData.buffer, 0, syncData.buffer, 0, this.backend.querySetCount * 8); this.backend.flush(); const kernelId = this.backend.currentKernelId!; @@ -96,12 +84,12 @@ export class ProgramManager { syncData.buffer.unmap(); - if (typeof this.backend.profilingTimeBase === 'undefined') { - this.backend.profilingTimeBase = startTimeU64; + if (typeof this.backend.queryTimeBase === 'undefined') { + this.backend.queryTimeBase = startTimeU64; } - const startTime = Number(startTimeU64 - this.backend.profilingTimeBase); - const endTime = Number(endTimeU64 - this.backend.profilingTimeBase); + const startTime = Number(startTimeU64 - this.backend.queryTimeBase); + const endTime = Number(endTimeU64 - this.backend.queryTimeBase); if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) { throw new RangeError('incorrect timestamp range'); diff --git a/js/web/package.json b/js/web/package.json index 15f13600c050e..7271fed99d709 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -68,14 +68,8 @@ ".": { "node": "./dist/ort.node.min.js", "default": { - "import": { - "development": "./dist/esm/ort.js", - "default": "./dist/esm/ort.min.js" - }, - "require": { - "development": "./dist/cjs/ort.js", - "default": "./dist/cjs/ort.min.js" - }, + "import": "./dist/esm/ort.min.js", + "require": "./dist/cjs/ort.min.js", "default": { "development": "./dist/ort.js", "default": "./dist/ort.min.js" @@ -83,88 +77,37 @@ } }, "./experimental": { - "import": { - "development": "./dist/esm/ort.all.js", - "default": "./dist/esm/ort.all.min.js" - }, - "require": { - "development": "./dist/cjs/ort.all.js", - "default": "./dist/cjs/ort.all.min.js" - }, + "import": "./dist/esm/ort.all.min.js", + "require": "./dist/cjs/ort.all.min.js", "default": { "development": "./dist/ort.all.js", "default": "./dist/ort.all.min.js" } }, "./wasm": { - "import": { - "development": "./dist/esm/ort.wasm.js", - "default": "./dist/esm/ort.wasm.min.js" - }, - "require": { - "development": "./dist/cjs/ort.wasm.js", - "default": "./dist/cjs/ort.wasm.min.js" - }, - "default": { - "development": "./dist/ort.wasm.js", - "default": "./dist/ort.wasm.min.js" - } + "import": "./dist/esm/ort.wasm.min.js", + "require": "./dist/cjs/ort.wasm.min.js", + "default": "./dist/ort.wasm.min.js" }, "./wasm-core": { - "import": { - "development": "./dist/esm/ort.wasm-core.js", - "default": "./dist/esm/ort.wasm-core.min.js" - }, - "require": { - "development": "./dist/cjs/ort.wasm-core.js", - "default": "./dist/cjs/ort.wasm-core.min.js" - }, - "default": { - "development": "./dist/ort.wasm-core.js", - "default": "./dist/ort.wasm-core.min.js" - } + "import": "./dist/esm/ort.wasm-core.min.js", + "require": "./dist/cjs/ort.wasm-core.min.js", + "default": "./dist/ort.wasm-core.min.js" }, "./webgl": { - "import": { - "development": "./dist/esm/ort.webgl.js", - "default": "./dist/esm/ort.webgl.min.js" - }, - "require": { - "development": "./dist/cjs/ort.webgl.js", - "default": "./dist/cjs/ort.webgl.min.js" - }, - "default": { - "development": "./dist/ort.webgl.js", - "default": "./dist/ort.webgl.min.js" - } + "import": "./dist/esm/ort.webgl.min.js", + "require": "./dist/cjs/ort.webgl.min.js", + "default": "./dist/ort.webgl.min.js" }, "./webgpu": { - "import": { - "development": "./dist/esm/ort.webgpu.js", - "default": "./dist/esm/ort.webgpu.min.js" - }, - "require": { - "development": "./dist/cjs/ort.webgpu.js", - "default": "./dist/cjs/ort.webgpu.min.js" - }, - "default": { - "development": "./dist/ort.webgpu.js", - "default": "./dist/ort.webgpu.min.js" - } + "import": "./dist/esm/ort.webgpu.min.js", + "require": "./dist/cjs/ort.webgpu.min.js", + "default": "./dist/ort.webgpu.min.js" }, "./training": { - "import": { - "development": "./dist/esm/ort.training.wasm.js", - "default": "./dist/esm/ort.training.wasm.min.js" - }, - "require": { - "development": "./dist/cjs/ort.training.wasm.js", - "default": "./dist/cjs/ort.training.wasm.min.js" - }, - "default": { - "development": "./dist/ort.training.wasm.js", - "default": "./dist/ort.training.wasm.min.js" - } + "import": "./dist/esm/ort.training.wasm.min.js", + "require": "./dist/cjs/ort.training.wasm.min.js", + "default": "./dist/ort.training.wasm.min.js" } }, "types": "./types.d.ts", diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc index 945c3aebce579..d0abf58922f88 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc @@ -32,8 +32,10 @@ Status ATen::Compute(OpKernelContext* p_ctx) const { aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size, dlpack_outputs.get()); for (size_t i = 0; i < output_size; ++i) { - ORT_RETURN_IF_ERROR( - p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + if (dlpack_outputs[i]) { + ORT_RETURN_IF_ERROR( + p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index be9650d96b004..d72868cd8fa9f 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index); +typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); - p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); + void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); + p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) { - ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); - return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index); + bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); + return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; + IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 5184dd99309b1..0fd8790e0d29d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -55,6 +55,7 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; + int num_splits; bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; @@ -95,9 +96,9 @@ struct GroupQueryAttentionParameters { int head_size; int kv_hidden_size; int kv_num_heads; + int num_splits; // number of splits for splitkv bool is_unidirectional; // causal float scale; - int num_splits; // number of splits for splitkv AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 0dc7de0e9e519..bf6431cf1afb2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -135,8 +135,24 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif if (!use_flash_attention) { @@ -279,6 +295,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index b4a4ae208ceb1..eb9e6d5c62467 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -316,7 +316,9 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, + parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + true)); DUMP_TENSOR("flash attention output", data.output, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d0a5fb51a25d6..3e78978c3cc43 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -88,6 +88,11 @@ struct AttentionData { T* v = nullptr; T* scratch = nullptr; AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index ff7a22d253a5b..89a27c4d2b0d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -140,11 +140,10 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, - int max_splits, bool new_kv, bool is_sm8x) { + int max_splits) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64)) - : (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64)); - const int num_n_blocks = (seqlen_k + (!new_kv ? 0 : seqlen_q) + block_n - 1) / block_n; + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. const int num_m_blocks = (seqlen_q + 64 - 1) / 64; @@ -190,6 +189,26 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea return 1; } +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 0a0328edb0059..58f4304251872 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -31,6 +31,7 @@ #if USE_FLASH_ATTENTION #include "core/providers/cuda/cuda_common.h" +#include namespace onnxruntime { namespace flash { @@ -99,10 +100,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, ); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); -size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q); -size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded); -int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, int max_splits, bool new_kv, bool is_sm8x); +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 784335a124c75..82dfa59b8f8e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -123,17 +123,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { - bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; constexpr int kBlockM = 64; // Fixed for all head dimensions - if (!is_sm8x) { // A100, H100 - // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, - // and for headdim 192 with block size 64 x 128. - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); } template diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 65d19d4473872..67d750aeac11a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -116,22 +116,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { size_t out_accum_bytes = 0; size_t seqlens_k_bytes = 0; if (use_flash_attention) { + // softmax buffer softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); - // split kv buffers - parameters.num_splits = onnxruntime::flash::num_splits_heuristic( + // split kv buffer + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, - parameters.head_size, device_prop.multiProcessorCount, 128, false, - device_prop.major == 8 && device_prop.minor > 0); - if (parameters.num_splits > 1) { - // softmax_lse_accum buffer - softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length); - // out_accum buffer - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(parameters.head_size, 32); - out_accum_bytes = onnxruntime::flash::get_out_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, head_size_rounded); - } + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; // seqlens_k buffer if (past_key != nullptr) { seqlens_k_bytes = sizeof(int) * parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index e3f53ca6a63cb..ebd66d8c6528e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -153,8 +153,24 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif bool use_fused_cross_attention = !use_flash_attention && @@ -291,6 +307,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index e4b09b00f030c..973ef8d304e2e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -51,11 +51,11 @@ half maybe2half(float x) { // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. -constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048}; +constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192}; constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); constexpr int kMaxSize = kSizes[kNumOfSizes - 1]; constexpr int kMinBlockSize = 32; -constexpr int kMaxBlockSize = 256; +constexpr int kMaxBlockSize = 1024; int NextSize(int x) { for (size_t i = 0; i < kNumOfSizes; ++i) { @@ -63,14 +63,13 @@ int NextSize(int x) { return kSizes[i]; } } - return kMaxSize; + return kMaxSize + 1; } -template -bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias, - const T* gamma, const T* beta, const int ld, const int next_size) { - constexpr int alignment = std::alignment_of>::value; - return ld % NumUnroll == 0 && +bool CanVectorized(void* output, void* sum_output, const void* input, const void* skip, const void* bias, + const void* gamma, const void* beta, const int ld, const int next_size, int num_unroll, int element_size) { + int alignment = element_size * num_unroll; + return ld % num_unroll == 0 && reinterpret_cast(output) % alignment == 0 && reinterpret_cast(sum_output) % alignment == 0 && reinterpret_cast(input) % alignment == 0 && @@ -78,8 +77,8 @@ bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, cons reinterpret_cast(bias) % alignment == 0 && reinterpret_cast(gamma) % alignment == 0 && reinterpret_cast(beta) % alignment == 0 && - next_size / NumUnroll >= kMinBlockSize && - next_size / NumUnroll <= kMaxBlockSize; + next_size / num_unroll >= kMinBlockSize && + next_size / num_unroll <= kMaxBlockSize; } } // namespace @@ -187,8 +186,14 @@ void LaunchSkipLayerNormKernel( int ld, int row_count, int skip_size) { const int next_size = NextSize(ld); const int grid_size = row_count; - bool flag_vec2 = CanVectorized(output, sum_output, input, skip, bias, gamma, beta, ld, next_size); - bool flag_vec4 = CanVectorized(output, sum_output, input, skip, bias, gamma, beta, ld, next_size); + bool can_unroll_vec4 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 4, sizeof(T)); + bool can_unroll_vec8 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 8, sizeof(T)); #define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ SkipLayerNormKernelSmall<<>>( \ @@ -198,39 +203,42 @@ void LaunchSkipLayerNormKernel( SkipLayerNormKernel<<>>( \ output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) -#define CASE_NEXT_SIZE(next_size_value) \ - case next_size_value: { \ - static_assert(next_size_value > kSizes[0] && next_size_value < kMaxSize); \ - if (flag_vec4) { \ - constexpr int block_size = next_size_value / 4; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ - } else if (flag_vec2) { \ - constexpr int block_size = next_size_value / 2; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \ - } else { \ - if (next_size_value <= kMaxBlockSize) { \ - constexpr int block_size = next_size_value; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ - } else { \ - constexpr int block_size = 256; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ - } \ - } \ +#define CASE_NEXT_SIZE(next_size_value) \ + case next_size_value: { \ + static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \ + if constexpr (next_size_value >= 8 * 256) { \ + if (can_unroll_vec8) { \ + constexpr int block_size = next_size_value / 8; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } else { \ + if (can_unroll_vec4) { \ + constexpr int block_size = next_size_value / 4; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ + } else { \ + if (next_size_value <= kMaxBlockSize) { \ + constexpr int block_size = next_size_value; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } \ + } \ } break switch (next_size) { - case kSizes[0]: { - constexpr int block_size = kSizes[0]; - // TODO: Add back the small TensorRT kernel for 32. No need to use vertorized kernel for such small size. - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); - break; - } + CASE_NEXT_SIZE(kSizes[0]); CASE_NEXT_SIZE(kSizes[1]); CASE_NEXT_SIZE(kSizes[2]); CASE_NEXT_SIZE(kSizes[3]); CASE_NEXT_SIZE(kSizes[4]); CASE_NEXT_SIZE(kSizes[5]); - // kMaxSize shall not run vectorized kernel since ld might be larger than kMaxSize. + CASE_NEXT_SIZE(kSizes[6]); + CASE_NEXT_SIZE(kSizes[7]); default: { constexpr int block_size = 256; LAUNCH_SKIP_LAYER_NORM_KERNEL(); diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc new file mode 100644 index 0000000000000..3cfa3ab959343 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_expand.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/expand.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedExpand::DistributedExpand(const OpKernelInfo& info) : DistributedKernel(info) {} + +template +Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + // Assumptions. + // - Shape is not sharded. + // Algorithm. + // - Compute logical output shape. + // - Compute local output shape. + // - Expand from local input to local output. + + auto input_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "It's not worth to shard Shape tensor. " + "If sharding shape is needed, please submit a feature request."); + // Compute logical input shape. + const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec); + + // Compute logical output shape. + // This `shape_tensor` stores the logical output shape. + const auto* p_shape = shape_tensor->Data(); + TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()}; + TensorShape original_output_shape(original_output_dims); + ORT_ENFORCE( + onnxruntime::cuda::ComputeOutputShape( + Node().Name(), + original_input_shape, + original_output_dims, original_output_shape) + .IsOK()); + + // Compute local output shape. + const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec); + + auto output_tensor = context->Output(0, local_output_shape); + + return FuncExpand( + this, + context, + input_tensor, + shape_tensor, + output_tensor); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h new file mode 100644 index 0000000000000..dedb1bdc5aa36 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedExpand final : public DistributedKernel { + public: + explicit DistributedExpand(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc new file mode 100644 index 0000000000000..a0ac40defbee7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc @@ -0,0 +1,861 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reshape.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +// Return true if src_shape[src_begin:src_end] is the same as +// dst_shape[dst_begin:dst_end]. Otherwise, return false. +// TODO: replace std::vector with gsl::span. +bool CompareSubVectors( + const std::vector& src_shape, + const std::vector& dst_shape, + size_t src_begin, size_t src_end, + size_t dst_begin, size_t dst_end) { + if (src_end - src_begin != dst_end - dst_begin) { + // Sub-vectors have different lengths. + return false; + } + for (size_t src_index = src_begin, dst_index = dst_begin; + src_index < src_end && dst_index < dst_end; + ++src_index, ++dst_index) { + if (src_shape[src_index] != dst_shape[dst_index]) { + // Sub-vectors have different elements. + return false; + } + } + // Sub-vectors have same length and same elements. + return true; +} + +// TODO: replace std::vector with gsl::span. +std::tuple IsTwoAxisFusion( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether two consecutive axes are fused. + // - size_t: the axis in destination shape formed by fusing two source axes. + // - size_t: the first axis fused. + // - size_t: the length of fusion. In two-axis fusion considered by this + // function, the length of fusion is always 2. + const size_t src_rank = src_shape.size(); + const size_t dst_rank = dst_shape.size(); + if (src_rank < 2 || dst_rank < 1) { + return std::make_tuple(false, -1, -1, -1); + } + if (src_rank - 1 != dst_rank) { + return std::make_tuple(false, -1, -1, -1); + } + for (size_t i_src = 0; i_src < src_rank; ++i_src) { + if (i_src + 1 > src_rank - 1) { + // We are at src_shape[i] and we need + // src_shape[i + 1] to fuse. + // If we are at the last axis, we cannot fuse. + break; + } + const int64_t prod = src_shape[i_src] * src_shape[i_src + 1]; + + for (size_t i_dst = 0; i_dst < dst_rank; ++i_dst) { + // Check if shape[i_src:i_src+2] (i.e., shape[i_src] and shape[i_src+1]) + // for source tensor are fused into shape[i_dst] for destination tensor. + if (prod != dst_shape[i_dst]) { + continue; + } + // Check if corresponding dimensions before fusion area + // are the same. + const bool prefix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[0:i_src]. + 0, i_src, + // Represent dst_shape[0:i_dst]. + 0, i_dst); + const bool suffix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[i_src+2:]. + i_src + 2, src_rank, + // Represent dst_shape[i_dst+1:]. + i_dst + 1, dst_rank); + if (prefix_shape_match && suffix_shape_match) { + return std::make_tuple( + true, i_dst, i_src, 2); + } + } + } + return std::make_tuple(false, 0, 0, 0); +} + +std::tuple IsTwoAxisDecomposition( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether one source axis is decomposed into two consecutive destination axes. + // - size_t: the axis in source shape decomposed into two consecutive destination axes. + // - size_t: the first axis the source axis decomposed into. + // - size_t: the number of decomposed axes. It's always 2 in this function. + return IsTwoAxisFusion(dst_shape, src_shape); +} + +std::vector RepeatVector(const std::vector& vec, int64_t repeat) { + std::vector new_vec; + for (int64_t i = 0; i < repeat; ++i) { + new_vec.insert(new_vec.end(), vec.begin(), vec.end()); + } + return new_vec; +} + +DeviceMesh CreateInterleaveDeviceMesh( + const DeviceMesh& source_mesh, const int64_t repeat) { + // Given a 1-D device mesh [0, 1] and repeat=2, + // return 1-D device mesh [0, 1, 0, 1]. + if (source_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source mesh shape 1-D."); + } + + // Mesh to return. + DeviceMesh new_mesh; + + std::vector& elements = new_mesh.device_mesh_elements; + for (int64_t i = 0; i < repeat; ++i) { + elements.insert( + elements.end(), + source_mesh.device_mesh_elements.begin(), + source_mesh.device_mesh_elements.end()); + } + + // source mesh must be 1-D so we only care its 1st dimension. + new_mesh.device_mesh_shape.push_back(source_mesh.device_mesh_shape[0] * repeat); + + return new_mesh; +} + +std::tuple ComputeNativeSpecForTwoAxisFusion( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t fused_axis_in_src, + const int64_t fusion_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. + // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example: RS[0], shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1, 0, 1] + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + ORT_ENFORCE(src_spec.CountShardingAxes() == 1, "Tensor to be reshaped has too many sharding axes."); + ORT_ENFORCE(src_spec.device_mesh.device_mesh_shape.size() == 1, "Source device mesh be 1-D."); + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src)) { + // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example 1: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 0, 0, 0, 0], (device assignment) + // [1, 1, 1, 1, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[ 0, 1, 2, 3, 4, 5, 6, 7]] + // - Device 1's local tensor (shape: [2, 4]). + // [[ 8, 9, 10, 11, 12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from shape [2, 4] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 2: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 4, 5], + // [ 6, 7]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 3: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [1, 1], + // [1, 1], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 8, 9], + // [10, 11]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 4, 5], + // [ 6, 7], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor. + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1, 0, 1] + + // Reuse original device mesh but shard the fusion axis in output tensor. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), src_spec.device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src + 1)) { + // Example 1 of determining native output sharding spec: + // - logical input shape: [3, 4] + // - logical output shape: [12] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 1, 0, 1], (device assignment) + // [0, 1, 0, 1], + // [0, 1, 0, 1]] + // [[0, 1, 2, 3], (values) + // [4, 5, 6, 7], + // [8, 9, 10, 11]], + // - Device 0's local tensor. + // [[0, 0], + // [0, 0], + // [0, 0]] + // [[0, 2], + // [4, 6], + // [8, 10]], + // - Device 1's local tensor. + // [[1, 1], + // [1, 1], + // [1, 1]] + // [[1, 3], + // [5, 7], + // [9, 11]], + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [6] by fusing both axes in shape [3, 2]. + // 3. Run local reshape (reshape from [3, 2] to [6]): + // - Device 0's local output tensor. + // [0, 0, 0, 0, 0, 0] + // [0, 2, 4, 6, 8, 10] + // - Device 1's local output tensor. + // [1, 1, 1, 1, 1, 1] + // [1, 3, 5, 7, 9, 11] + // 4. Determine native output sharding spec by comparing local output tensors and logical tensor. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 2 of determining native output sharding spec: + // - logical input shape: [3, 8] + // - logical output shape: [24] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1], + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15], + // [16, 17, 18, 19, 20, 21, 22, 23]] + // - Device 0's local tensor (shape: [3, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13], + // [16, 17, 20, 21]] + // - Device 1's local tensor (shape: [3, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15], + // [18, 19, 22, 23]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [12] by fusing both axes in shape [3, 4]. + // 3. Run local reshape (reshape from [3, 4] to [12]): + // - Device 0's local output tensor . + // [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21] + // - Device 1's local output tensor . + // [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = . + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 3: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 4, 5, 8, 9, 12, 13] + // - Device 1's local output tensor . + // [ 2, 3, 6, 7, 10, 11, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 2 + // + // Example 4: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 1, 1, 1, 1], (device assignment) + // [0, 0, 0, 0, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 2, 3], + // [ 8, 9, 10, 11]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 4, 5, 6, 7], + // [12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor . + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1] = [0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1] * (first fused dimension) = [0, 1] * 2 = [0, 1, 0, 1] + + // The output device mesh is the repeats of the original device. + // Let's use Python syntax. If the original device mesh is [0, 1, 0, 1], and + // the first fused dimension is 3, then the output device mesh is [0, 1, 0, 1] * 3. + auto dst_device_mesh = DeviceMesh::Create1D( + src_spec.device_mesh.device_mesh_elements, + src_shape[fused_axis_in_src]); + // Sharding happens in the fusion axis with the new device mesh. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), dst_device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && (src_spec.GetPartitionAxis() < fused_axis_in_src || src_spec.GetPartitionAxis() > fused_axis_in_src + 1)) { + // It's two-axis fusion but the fused axes is not sharded. + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + auto dst_spec = TensorPartitionSpec::CreateByDropOneAxis( + src_spec, fused_axis_in_src + 1); + return std::make_tuple(true, dst_spec); + } else { + return std::make_tuple(false, TensorPartitionSpec()); + } +} + +// Arguments: +// - device_elements: a vector of device IDs. +// It should only contain unique device IDs or +// repeats of a list of unique device IDs. Otherwise, +// (0, 0) is returned. +// Returns: +// - count per device ID (all device IDs should have the same count) +// - number of unique device IDs +// Examples: +// - [0, 1] -> (2, 1) +// - [0, 1, 2, 0, 1, 2] -> (2, 3) +std::tuple ComputeRepeatAndRepeatStride( + const std::vector& device_elements) { + int64_t first_device_id = device_elements.at(0); + int64_t first_device_id_count = 0; + for (size_t i = 0; i < device_elements.size(); ++i) { + if (device_elements.at(i) == first_device_id) { + ++first_device_id_count; + } + } + size_t repeat_stride = device_elements.size() / first_device_id_count; + + // Check if the device mesh pattern is supported. + // Supported examples: [0, 1, 2] and [0, 1, 0, 1, 0, 1]. + // Unsupported examples: [0, 1, 2, 1, 2, 0] and [0, 1, 2, 0]. + for (size_t repeat = 0; repeat < first_device_id_count; ++repeat) { + for (size_t device_id = 0; device_id < repeat_stride; ++device_id) { + ORT_ENFORCE( + device_elements.at(repeat * repeat_stride + device_id) == device_elements.at(device_id), + "Unsupported device mesh pattern."); + } + } + + // If device_mesh=[0, 1, 2, 0, 1, 2], returns (2, 3), which means + // - each device repeats twice for "2" in (2, 3). + // - there are 3 unique devices for "3" in (2, 3). + return std::make_tuple(first_device_id_count, repeat_stride); +} + +std::tuple ComputeNativeSpecForTwoAxisDecomposition( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t decomposed_axis_in_src, + const int64_t decomposition_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. + // Example: S[0], shape=[8], device_mesh=[0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1] -> RS[0] + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> RS[0] + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RS[0]RR + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RRS[0]R + if (src_spec.CountShardingAxes() != 1) { + throw std::runtime_error("Too many sharding axes."); + } + if (src_spec.device_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source device mesh be 1-D."); + } + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.OnlyShardAxis(decomposed_axis_in_src)) { + const int64_t device_stride = src_shape[decomposed_axis_in_src] / src_spec.device_mesh.device_mesh_shape[0]; + if (device_stride >= dst_shape[decomposition_axis_in_dst + 1] && device_stride % dst_shape[decomposition_axis_in_dst + 1] == 0) { + // Since 2nd decomposition dimension is a factor of device stride, + // Sharding happens at 1st decomposition axis in dst. + // device_stride = 10 + // S[0], shape=[20], device=[0, 1] -> S[0]R, shape=[2, 10], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> RS[0], shape=[1, 16], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> S[0]R, shape=[4, 4], device=[0, 1] + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + // Sharding spec is copied if the axis is not decomposed. + // E.g, shape [5, 6] -> Reshape -> shape [5, 3, 2] + // The spec for "5" is copied. + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // The spec for "5" is copied and "1" is replica. + // This reshape only adds a dummy new axis without affecting + // the underlying sharding status. + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + } else { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + // Now, we know sharding happens at decomposed_axis_in_src axis in destination tensor. + // - effective_device_stride along decomposed_axis_in_src: device_stride / dst_shape[decomposed_axis_in_src + 1] + // - The original device patterns repeats: dst_shape[decomposed_axis_in_src] / effective_device_stride times. + const int64_t effective_device_stride = device_stride / dst_shape[decomposed_axis_in_src + 1]; + // How many times a device ID changes along decomposed_axis_in_src axis in destination tensor. + const int64_t number_of_device_changes = dst_shape[decomposed_axis_in_src] / effective_device_stride; + if ((size_t)number_of_device_changes != src_spec.device_mesh.device_mesh_elements.size()) { + throw std::runtime_error("Not supported. Resharding is required."); + } + auto dst_device_mesh = CreateInterleaveDeviceMesh( + src_spec.device_mesh, 1); + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else if (dst_shape[decomposition_axis_in_dst + 1] > device_stride && dst_shape[decomposition_axis_in_dst + 1] % device_stride == 0) { + // Since 2nd decomposition dimension is a multiple of device stride, + // sharding happens at 2nd decomposition axis in dst. + // stride = 4 + // S[0], shape=[8], device=[0, 1] -> S[0]R, shape=[4, 2], device=[0, 1] + // + // stride = 8 + // S[0], shape=[32], device=[0, 1, 0, 1] -> RS[0], shape=[2, 16], device=[0, 1] + std::vector dst_axis_specs; + // How many times a device ID appears. + // E.g., [0, 1, 0, 1, 0, 1] -> 3 + int64_t repeats = 0; + // Number of unique devices. + // E.g., [0, 1, 0, 1, 0, 1] -> 2 + int64_t repeat_stride = 0; + DeviceMesh dst_device_mesh; + std::tie(repeats, repeat_stride) = ComputeRepeatAndRepeatStride(src_spec.device_mesh.device_mesh_elements); + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (dst_shape[decomposition_axis_in_dst + 1] == 1) { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats == 1 && dst_shape[decomposition_axis_in_dst + 1] == device_stride * repeat_stride) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats != 1 && dst_shape[decomposition_axis_in_dst + 1] % (device_stride * repeat_stride) == 0) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + // Extract [0, 1] from [0, 1, 0, 1]. + std::vector unique_device_mesh_elements( + src_spec.device_mesh.device_mesh_elements.begin(), + src_spec.device_mesh.device_mesh_elements.begin() + repeat_stride); + // Compute new repeats. + // Example of repeats change from 2 to 1: + // [16]-shape tensor [2, 8]-shape tensor + // with 1-D device mesh -> Reshape -> with 1-D device mesh + // [0, 1, 0, 1] (repeats=2) [0, 1] (repeats=1) + const int64_t new_repeat = dst_shape[decomposition_axis_in_dst + 1] / (device_stride * repeat_stride); + dst_device_mesh.device_mesh_shape.push_back(repeat_stride); + dst_device_mesh.device_mesh_elements = RepeatVector(unique_device_mesh_elements, new_repeat); + } else { + throw std::runtime_error("Not supported. Resharding is required."); + } + } + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else { + // Not supported. Resharding is required. + return std::make_tuple(false, TensorPartitionSpec()); + } + } else { + // Source tensor is sharded on non-decomposed axis. + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else { + // R -> RR + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, src_spec.device_mesh)); + } +} + +// Arguments: +// global_data_shape: logical shape of Reshape's 1st input. +// global_shape_span: logical content of Reshape's 2nd input. +// Returns: +// logical shape of Reshape's output. +inline TensorShape InferDistributedReshapeLogicalOutputShape( + const TensorShape& global_data_shape, + const gsl::span& global_shape_span, + const int64_t allow_zero) { + return onnxruntime::cuda::InferReshapeOutputShape( + global_data_shape, + global_shape_span, + allow_zero); +} + +template +DistributedReshape::DistributedReshape(const OpKernelInfo& info) : DistributedKernel(info) { + allow_zero_ = info.GetAttrOrDefault("allowzero", static_cast(0)); +} + +template +Status DistributedReshape::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + auto data_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& data_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + if (data_sharding_spec.HasNoShard() && shape_sharding_spec.HasNoShard() && output_sharding_spec.HasNoShard()) { + // Case: all inputs and outputs are not sharded. + const auto target_shape = onnxruntime::cuda::InferReshapeOutputShape( + data_tensor, + shape_tensor, + allow_zero_); + + auto output_tensor = context->Output(0, target_shape); + + // Copy data from input from output. + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "Shape tensor should not be sharded because it will trigger communication. " + "If sharding shape is needed, please request this feature on Github."); + ORT_ENFORCE(shape_tensor->Shape().NumDimensions() == 1, "Shape must be a 1-D tensor."); + const auto original_data_shape = ComputeOriginShape(data_tensor->Shape(), data_sharding_spec); + const auto original_output_shape = InferDistributedReshapeLogicalOutputShape( + original_data_shape, + shape_tensor->template DataAsSpan(), + allow_zero_); + + // TODO: remove below code after replacing std::vector with TensorShape in other APIs. + std::vector src_shape(original_data_shape.GetDims().begin(), original_data_shape.GetDims().end()); + std::vector dst_shape(original_output_shape.GetDims().begin(), original_output_shape.GetDims().end()); + + // Case: Two axis fusion + bool is_two_axis_fusion = false; + size_t two_axis_fusion_axis_in_dst = 0; + size_t two_axis_fusion_first_fused_axis_in_src = 0; + size_t two_axis_fusion_fused_axis_count = 0; + std::tie( + is_two_axis_fusion, + two_axis_fusion_axis_in_dst, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_fused_axis_count) = IsTwoAxisFusion(src_shape, dst_shape); + + if (is_two_axis_fusion) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisFusion( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. + TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + + // Case: Two axis decomposition + bool is_two_axis_decomposition = false; + size_t two_axis_decomposition_decomposed_axis_in_src = 0; + size_t two_axis_decomposition_first_factor_axis_in_dst = 0; + size_t two_axis_decomposition_factor_axis_count_in_dst = 0; + std::tie( + is_two_axis_decomposition, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst, + two_axis_decomposition_factor_axis_count_in_dst) = IsTwoAxisDecomposition(src_shape, dst_shape); + + if (is_two_axis_decomposition) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisDecomposition( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. + TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + } + + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h new file mode 100644 index 0000000000000..e251c3cdc38d7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/framework/tensor_shape.h" +#include "core/providers/cuda/tensor/reshape.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReshape final : public DistributedKernel { + public: + explicit DistributedReshape(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t allow_zero_; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc index dfd5f589355df..b6b509023a1a9 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -30,7 +30,7 @@ void GatherTensor( const Tensor* tensor, Tensor* gathered) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); FuncAllGather( nccl_kernel, @@ -51,7 +51,7 @@ std::unique_ptr GatherTensor( const TensorPartitionSpec& spec, const Tensor* tensor) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); TensorShape gathered_shape(tensor->Shape()); gathered_shape[shard_axis] *= shard_count; @@ -82,7 +82,7 @@ void ShardTensor( const Tensor* tensor, Tensor* shard_tensor) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); TensorShape shard_shape = ComputeShardShape( tensor->Shape(), shard_axis, @@ -118,7 +118,7 @@ std::unique_ptr ShardTensor( TensorShape shard_shape = ComputeShardShape( tensor->Shape(), spec.GetPartitionAxis(), - spec.GetPartitionCount(spec.GetPartitionAxis())); + spec.GetUniqueDeviceCount(spec.GetPartitionAxis())); auto shard_buffer = Tensor::Create(tensor->DataType(), shard_shape, alloc); // Shard with pre-allocated buffer. diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc index 220938f3ceaef..20c936e1b6718 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc @@ -129,7 +129,7 @@ TensorShape ComputeOriginShape(const TensorShape& shard_shape, const TensorParti } TensorShape shape(shard_shape); const int64_t axis = spec.GetPartitionAxis(); - shape[axis] *= spec.GetPartitionCount(axis); + shape[axis] *= spec.GetUniqueDeviceCount(axis); return shape; } @@ -140,7 +140,15 @@ TensorShape ComputeShardShape(const TensorShape& shape, const TensorPartitionSpe return shard_shape; } const int64_t axis = spec.GetPartitionAxis(); - shard_shape[axis] /= spec.GetPartitionCount(axis); + const int64_t unique_device_count = spec.GetUniqueDeviceCount(axis); + ORT_ENFORCE(shard_shape[axis] % unique_device_count == 0, "Number of shards must be divisible by sharded axis' dimension."); + // If a [8, 16]-tensor is shared by device mesh [0, 1, 0, 1] along axis=1 (2nd axis), + // the local tensors on device 0 & 1 have same shape [8, 8 (from 16/2)] instead of + // [8, 4 (from 16/4)]. The reason is that + // - First, the original tensor are split into 4 sub-tensors [8, 4] along the 2nd axis. + // - The 1st and 3rd sub-tensors are concatenated along axis=1 to one tensor on device 0. + // - The 2nd and 4th sub-tensors are concatenated along axis=1 to one tensor on device 1. + shard_shape[axis] /= unique_device_count; return shard_shape; } @@ -202,7 +210,7 @@ bool CanShard(const TensorShape& shape, const TensorPartitionSpec& spec) { if (axis < 0 || gsl::narrow(axis) >= shape.NumDimensions()) { return false; } - if (shape[axis] % spec.GetPartitionCount(axis) != 0) { + if (shape[axis] % spec.GetDeviceCount(axis) != 0) { return false; } return true; diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 451d44b4bd434..5185c41e6888c 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -76,6 +76,43 @@ class DeviceMesh { void Print() const { std::cout << ToString() << std::endl; } + + static DeviceMesh Create1D(std::vector device_mesh_elements, size_t repeats = 1) { + DeviceMesh device_mesh; + device_mesh.device_mesh_shape.push_back(device_mesh_elements.size() * repeats); + for (size_t i = 0; i < repeats; ++i) { + device_mesh.device_mesh_elements.insert( + device_mesh.device_mesh_elements.end(), + device_mesh_elements.begin(), + device_mesh_elements.end()); + } + return device_mesh; + } + + // If the two meshes have the same shape and elements, return true. + // Otherwise, return false. + bool operator==(const DeviceMesh& other) const { + if (device_mesh_shape.size() != other.device_mesh_shape.size() || + device_mesh_elements.size() != other.device_mesh_elements.size()) { + return false; + } + + for (size_t i = 0; i < device_mesh_elements.size(); ++i) { + if (device_mesh_elements.at(i) != other.device_mesh_elements.at(i)) { + return false; + } + } + for (size_t i = 0; i < device_mesh_shape.size(); ++i) { + if (device_mesh_shape.at(i) != other.device_mesh_shape.at(i)) { + return false; + } + } + return true; + } + + bool operator!=(const DeviceMesh& other) const { + return !(*this == other); + } }; class AxisPartitionSpec { @@ -114,10 +151,14 @@ class AxisPartitionSpec { return AxisPartitionSpec(Condition::Shard, device_mesh_axis); } + static AxisPartitionSpec CreateCopy(const AxisPartitionSpec& spec) { + return AxisPartitionSpec(spec.cond, spec.device_mesh_axis); + } + // A normal ctor. // TODO(wechi): Consider to hide it and revise the `public` members/functions // exposed to the user. - AxisPartitionSpec(Condition cond_, int device_mesh_axis_) : device_mesh_axis(device_mesh_axis_), cond(cond_) {} + AxisPartitionSpec(Condition cond_, int device_mesh_axis_) : cond(cond_), device_mesh_axis(device_mesh_axis_) {} // Helper to debug and generate error message; e.g., // "RS[0]". @@ -132,6 +173,14 @@ class AxisPartitionSpec { void Print() const { std::cout << ToString() << std::endl; } + + bool operator==(const AxisPartitionSpec& other) const { + return cond == other.cond && device_mesh_axis == other.device_mesh_axis; + } + + bool operator!=(const AxisPartitionSpec& other) const { + return !(*this == other); + } }; // Return true if `axis` is a valid axis index for a tensor of rank `rank`. @@ -193,6 +242,32 @@ class TensorPartitionSpec { // const TensorPartitionSpec& spec, int64_t new_shard_axis) { // } + // Copy-construct `spec` but with all tensor axes replicated. + // The new spec have the same number of axis specs and the same device mesh. + static TensorPartitionSpec CreateAllReplica( + const size_t rank, const DeviceMesh& device_mesh) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateOneTensorAxisOneDeviceMeshAxisSharding( + const size_t rank, const DeviceMesh& device_mesh, const size_t tensor_axis, const size_t device_mesh_axis) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + axis_specs[tensor_axis] = AxisPartitionSpec::CreateShard(device_mesh_axis); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateByDropOneAxis( + const TensorPartitionSpec& TensorPartitionSpec, const size_t axis_to_drop) { + std::vector axis_specs; + for (size_t i = 0; i < TensorPartitionSpec.axis_specs.size(); ++i) { + if (i != axis_to_drop) { + axis_specs.push_back(TensorPartitionSpec.axis_specs[i]); + } + } + return TensorPartitionSpec::Create(axis_specs, TensorPartitionSpec.device_mesh); + } + // Helper to debug and generate error message; e.g., // "TensorPartitionSpec{RS[0], Device Mesh: DeviceMesh{Shape: [4,], Elements: [0,1,2,3,]}}". std::string ToString() const { @@ -303,7 +378,7 @@ class TensorPartitionSpec { // Return the number of shards along the first sharded tensor axis. // This value matches the number of devices along the associated mesh axis. // Return 1 if there is no sharding. - int64_t GetPartitionCount(int64_t axis) const { + int64_t GetDeviceCount(int64_t axis) const { ValidateAxisIndex(axis, Rank()); auto axis_spec = GetAxisSpec(axis); if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { @@ -312,6 +387,37 @@ class TensorPartitionSpec { return device_mesh.device_mesh_shape.at(axis_spec.device_mesh_axis); } } + + // Similar to GetDeviceCount(), but returns the number of unique devices + // along the first sharded tensor axis. + int64_t GetUniqueDeviceCount(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + auto axis_spec = GetAxisSpec(axis); + if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { + return 1; + } else { + std::set device_ids( + device_mesh.device_mesh_elements.begin(), + device_mesh.device_mesh_elements.end()); + return device_ids.size(); + } + } + + bool operator==(const TensorPartitionSpec& other) const { + if (axis_specs.size() != other.axis_specs.size()) { + return false; + } + for (size_t i = 0; i < axis_specs.size(); ++i) { + if (!(axis_specs.at(i) == other.axis_specs.at(i))) { + return false; + } + } + return device_mesh == other.device_mesh; + } + + bool operator!=(const TensorPartitionSpec& other) const { + return !(*this == other); + } }; // Parse "[0, 1, 2, 3]" as std::vector{0, 1, 2, 3}. diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index e762a80cb0e2f..d51915b85095f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -97,6 +97,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -144,6 +145,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8); #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); @@ -165,6 +167,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); #endif template <> @@ -260,6 +270,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -313,6 +324,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, @@ -334,6 +346,14 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 301b2e76b1b2d..87e88ac31c998 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/diffusion/group_norm.h" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" @@ -15,14 +14,22 @@ ONNX_OPERATOR_KERNEL_EX( GroupNorm, kMSDomain, 1, kCudaExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); +ONNX_OPERATOR_KERNEL_EX( + SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + using namespace ONNX_NAMESPACE; namespace { + template struct DispatchGroupNorm { Status operator()(cudaStream_t stream, Tensor* output, + Tensor* add_out, const Tensor* input, + const Tensor* skip, + const Tensor* bias, const Tensor* gamma, const Tensor* beta, void* workspace, @@ -32,12 +39,17 @@ struct DispatchGroupNorm { int height, int width, int num_groups, - bool use_swish_activation) { + bool use_swish_activation, + bool broadcast_skip, + int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( stream, reinterpret_cast(output->MutableData()), + add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), + skip == nullptr ? nullptr : reinterpret_cast(skip->Data()), + bias == nullptr ? nullptr : reinterpret_cast(bias->Data()), gamma->Data(), beta->Data(), workspace, @@ -47,13 +59,21 @@ struct DispatchGroupNorm { height, width, num_groups, - use_swish_activation); + use_swish_activation, + broadcast_skip, + channels_per_block); } }; } // namespace GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { + has_skip_ = false; + const std::string& op_name = op_info.GetKernelDef().OpName(); + if (op_name == "SkipGroupNorm") { + has_skip_ = true; + } + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); ORT_ENFORCE(epsilon_ >= 0); @@ -68,6 +88,23 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { use_swish_activation_ = (activation == 1); channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); + + channels_per_block_ = 0; +} + +Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + + // Compute and cache cPerBlock using number of channels from gamma tensor shape. + if (input_idx == 1) { + auto gamma_shape = tensor.Shape(); + if (gamma_shape.NumDimensions() == 1) { + channels_per_block_ = GetChannelsPerBlock(static_cast(gamma_shape[0]), num_groups_); + } + } + + return Status::OK(); } Status GroupNorm::ComputeInternal(OpKernelContext* context) const { @@ -77,22 +114,38 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "only the channels_last layout is supported"); } + if (!gamma->IsDataType() || !beta->IsDataType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm only supports gamma and beta in float type"); + } + const auto& input_dims = input->Shape().GetDims(); if (input_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input is expected to have 4 dimensions, got ", input_dims.size()); } + // Only support NHWC format right now. + int batch_size = static_cast(input_dims[0]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + int num_channels = static_cast(input_dims[3]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisiable by num_groups"); + } + const auto& gamma_dims = gamma->Shape().GetDims(); if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } - if (gamma_dims[0] != input_dims[3]) { + if (gamma_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in gamma and input does not match"); } @@ -102,22 +155,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } - if (beta_dims[0] != input_dims[3]) { + if (beta_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in beta and input does not match"); } - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisiable by num_groups"); - } - if (context->GetUseDeterministicCompute()) { static std::once_flag log_warning; std::call_once(log_warning, []() { @@ -125,17 +167,59 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { }); } - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + const Tensor* skip = nullptr; + const Tensor* bias = nullptr; + Tensor* add_out = nullptr; + + bool broadcast_skip = false; + if (has_skip_) { + skip = context->Input(3); + bias = context->Input(4); + add_out = context->Output(1, input->Shape()); + + if (bias != nullptr) { // Bias is optional + // If provided, bias has shape (C). + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != num_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in bias and input does not match"); + } + } + + // Check whether skip can be broadcasted to input shape. + if (skip->Shape() != input->Shape()) { + const auto& dims = skip->Shape().GetDims(); + // The shape of ship can be (N, C) or (N, 1, 1, C) for broadcast. + const bool b2 = (dims.size() == 2 && dims[0] == batch_size && dims[1] == num_channels); + const bool b4 = (dims.size() == 4 && dims[0] == batch_size && + dims[1] == 1 && dims[2] == 1 && dims[3] == num_channels); + broadcast_skip = b2 || b4; + if (!broadcast_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip shape is expected to be (N, H, W, C) or (N, 1, 1, C) or (N, C)"); + } + } + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_), + context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + gamma, beta, workspace.get(), epsilon_, batch_size, num_channels, height, width, num_groups_, - use_swish_activation_); + use_swish_activation_, + broadcast_skip, + channels_per_block_); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 52c006e6bdb96..b408b3c1ee79b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -16,11 +16,16 @@ class GroupNorm final : public CudaKernel { GroupNorm(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: - bool use_swish_activation_; + bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization? float epsilon_; int num_groups_; bool channels_last_; + bool has_skip_; // true for SkipGroupNorm operator; false for GroupNorm + int channels_per_block_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 01ba078b4be77..48b161552ce0c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -16,18 +16,45 @@ */ // The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +using namespace onnxruntime::cuda; + namespace onnxruntime { namespace contrib { namespace cuda { -static inline int32_t divUp(int32_t m, int32_t n) { +namespace { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} +} // namespace + +static inline int32_t DivUp(int32_t m, int32_t n) { return (m + n - 1) / n; } @@ -41,14 +68,14 @@ struct GroupSums { // The sum. float sum; // The sum of squares. - float sumSq; + float sum_sq; }; struct GroupSumsOp { inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { GroupSums dst; dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); dst.flag = a.flag + b.flag; return dst; } @@ -56,54 +83,85 @@ struct GroupSumsOp { template struct GroupNormNHWCParams { - // The output buffer. Layout NHWC. + // The output buffer. Shape is (n, h, w, c). T* dst; - // The input buffer. Layout NHWC. + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + // The gamma scaling factor. float const* gamma; + // The beta term to add in GN. float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; // The number of instances in the batch. int32_t n; + // The height and width of each activation map. int32_t h; int32_t w; - // The number of channels. + + // Number of channels. int32_t c; - // The number of groups. + + // Number of groups. int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; + + // Do we apply the SiLU activation function? + bool use_silu; // Precomputed values and parameters to control the execution of the kernels. - // The number of activations per instance (h * w) and the number of - // activations per block. + // Number of activations per instance (h * w) int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; // The precomputed stride between instances. int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; // The precomputed number of groups per block. - int32_t groupsPerBlock; + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; }; template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -113,11 +171,11 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); @@ -125,119 +183,220 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } -template -__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[tTHREADS_PER_BLOCK]; + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; // The instance in the batch. int32_t ni = blockIdx.z; - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t hw_begin = blockIdx.y * params.hw_per_block; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); // The sums. float sum = 0.F; - float sumSq = 0.F; + float sum_sq = 0.F; // Iterate over the activations to compute the sums. - if (ci < params.c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; - UpdateSum(params.src, offset, sum, sumSq); + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); } } - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * 2 / params.cPerGroup; - int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - // Do the segmented scan. + // Do the segmented scan. InclusiveScan is not deterministic. GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == params.cPerGroup - 2) { //2 channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); } // Make sure the data is in shared memory. __syncthreads(); - // The global group index. - int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + if (threadIdx.x >= params.groups_per_block) { return; } - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); - atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } } template -void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the values are as we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); + // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); + // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCSumKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCSumKernel<<>>(params); break; - case 480: - groupNormNHWCSumKernel<<>>(params); + case 192: + GroupNormNHWCSumKernel<<>>(params); break; - case 256: - groupNormNHWCSumKernel<<>>(params); + case 160: + GroupNormNHWCSumKernel<<>>(params); break; case 128: - groupNormNHWCSumKernel<<>>(params); + GroupNormNHWCSumKernel<<>>(params); + break; + case 64: + GroupNormNHWCSumKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); template <> -__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -245,15 +404,15 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo float2 f2 = __half22float2(h2); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -262,21 +421,21 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo } template <> -__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -284,110 +443,142 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f *reinterpret_cast(&dst[offset]) = f2; } -template -__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; - if (ci >= params.c) { +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { return; } // The instance in the batch. int32_t ni = blockIdx.z; - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / params.cPerGroup; + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; + float sum = 0.F, sum_sq = 0.F; if (gi < params.groups) { - sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; - sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; } - // Load gamma/beta. - float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.inv_hw_channels_per_group; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var); + float inv_std_dev = rsqrtf(var + params.epsilon); - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; - - // Fetch two channels per thread. - computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); } } template -void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCScaleKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCScaleKernel<<>>(params); break; - case 480: - groupNormNHWCScaleKernel<<>>(params); + case 192: + GroupNormNHWCScaleKernel<<>>(params); break; - case 256: - groupNormNHWCScaleKernel<<>>(params); + case 160: + GroupNormNHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<<>>(params); + GroupNormNHWCScaleKernel<<>>(params); + break; + case 64: + GroupNormNHWCScaleKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; for (int32_t i = 1; i <= std::sqrt(n); i++) { if (n % i == 0) { int32_t divisor1 = n / i; int32_t divisor2 = i; - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; } } } - return maxDivisor; + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; } template Status LaunchGroupNormKernel( cudaStream_t stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -397,79 +588,94 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCParams params; - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + int32_t channels_per_group = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } - GroupNormNHWCParams params; - int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; - switch (num_channels) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; + // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases + if (channels_per_block % channels_per_group != 0 || + channels_per_block > kMaxSize || + (channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in CUDA does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - params.withSwish = use_swish_activation; + params.use_silu = use_silu; params.dst = output; + params.add_out = add_out; params.src = input; + params.skip = skip; + params.bias = bias; params.gamma = gamma; params.beta = beta; - params.redBuffer = reinterpret_cast(workspace); + params.group_sum_buffer = reinterpret_cast(workspace); params.n = batch_size; params.h = height; params.w = width; params.c = num_channels; params.groups = num_groups; params.hw = params.h * params.w; - const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW); - params.hwPerBlock = divUp(params.hw, blocksPerHW); - params.cPerBlock = cPerBlock; - params.cPerGroup = params.c / params.groups; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); + params.hw_per_block = DivUp(params.hw, blocks_per_hw); + + params.channels_per_block = channels_per_block; + params.channels_per_group = channels_per_group; params.hwc = params.hw * params.c; - params.invHWC = 1.F / (float)(params.hw * params.cPerGroup); - params.groupsPerBlock = cPerBlock / params.cPerGroup; + params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); + params.groups_per_block = channels_per_block / params.channels_per_group; + params.epsilon = epsilon; + params.broadcast_skip = broadcast_skip; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("input", input, batch_size, num_channels, height * width); - DUMP_TENSOR("gamma", gamma, 1, num_channels); - DUMP_TENSOR("beta", beta, 1, num_channels); - cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); - groupNormNHWCSum(params, stream); - DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2); + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; + + params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); + + GroupNormNHWCSum(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - groupNormNHWCScale(params, stream); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("workspace", params.group_sum_buffer, batch_size, 2, num_groups); + + GroupNormNHWCScale(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_TENSOR("output", output, batch_size, num_channels, height * width); + return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, + const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, + const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index c7e9245050ee6..9532aeecb2f57 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -12,29 +12,33 @@ namespace onnxruntime { namespace contrib { namespace cuda { -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { +constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) { // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; + return (sizeof(float) * 2) * batch_size * num_groups; } +int GetChannelsPerBlock(int num_channels, int num_groups); + template Status LaunchGroupNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization + T* output, // normalized output tensor. Shape is (n, h, w, c) + T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) + const T* input, // input tensor. Shape is (n, h, w, c) + const T* skip, // optional skip tensor. Shape is (n, h, w, c) + const T* bias, // optional bias tensor. Shape is (c) for SkipGroupNorm or (n, c) for BiasGroupNorm + const float* gamma, // gamma (also known as weight or scale). Shape is (c) + const float* beta, // beta (also known as bias). Shape is (c) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_silu, // Whether there is Sigmoid Linear Unit (SiLU) activation after group normalization + bool broadcast_skip, // Whether skip need broadcast. When skip has shape (n, c) or (n, 1, 1, c), it need broadcast. + int channels_per_block // Pre-computed channels per block. ); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc new file mode 100644 index 0000000000000..251850f621361 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/cuda/math/gemm.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "contrib_ops/cuda/math/gemm_float8.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX( \ + GemmFloat8, \ + kMSDomain, \ + 1, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("TA", BuildKernelDefConstraints()) \ + .TypeConstraint("TB", BuildKernelDefConstraints()) \ + .TypeConstraint("TR", BuildKernelDefConstraints()) \ + .TypeConstraint("TS", BuildKernelDefConstraints()), \ + GemmFloat8); + +REGISTER_KERNEL() + +GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) { + transA_ = info.GetAttrOrDefault("transA", 0); + transB_ = info.GetAttrOrDefault("transB", 0); + dtype_ = info.GetAttrOrDefault("dtype", ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto& device_prop = GetDeviceProp(); + sm_count_ = device_prop.multiProcessorCount; + alpha_ = info.GetAttrOrDefault("alpha", 1); + beta_ = info.GetAttrOrDefault("beta", 0); + +#if (CUDA_VERSION <= 12000) + ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0."); +#endif + + std::string stemp = info.GetAttrOrDefault("activation", "NONE"); + if (stemp == "NONE") { + epilogue_ = CUBLASLT_EPILOGUE_DEFAULT; + } else if (stemp == "RELU") { + epilogue_ = CUBLASLT_EPILOGUE_RELU; + } else if (stemp == "GELU") { + epilogue_ = CUBLASLT_EPILOGUE_GELU; + } else { + ORT_THROW("Unexpected value for activation: '", stemp, "'."); + } +} + +Status GemmFloat8::SetCheck(const TensorShape& a_shape, const TensorShape& b_shape, int& M, int& N, int& K) const { + GemmHelper helper(a_shape, transA_, b_shape, transB_, TensorShape({})); + if (!helper.State().IsOK()) + return helper.State(); + + M = gsl::narrow_cast(helper.M()); + N = gsl::narrow_cast(helper.N()); + K = gsl::narrow_cast(helper.K()); + return helper.State(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu new file mode 100644 index 0000000000000..df25342342cd5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -0,0 +1,402 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// The operator calls function 'cublasLtMatmul' +// (https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul). +// It lets the function checks what configuration is valid or not. If not, the error message +// shows the error message 'CUBLAS_STATUS_NOT_SUPPORTED'. NVIDIA documentation provides +// information on what attribute or type must be modified. +// This operator requires CUDA_VERSION >= 11.8 for float 8 and CUDA_VERSION >= 12.0 +// for beta != 0. + +#include +#include +#include +#include "contrib_ops/cuda/math/gemm_float8.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// It must exist somewhere already. +int32_t TypeSize(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return 4; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return 2; +#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return 1; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + +void GemmFloat8::SetParams(const TensorShape& a_shape, const TensorShape& b_shape, + int& M, int& N, int& K, int& lda, int& ldb, int& ldd) const { + int m_idx = transA_ ? 1 : 0; + int k_idx = 1 - m_idx; + int n_idx = transB_ ? 0 : 1; + + M = static_cast(a_shape[m_idx]); + K = static_cast(a_shape[k_idx]); + N = static_cast(b_shape[n_idx]); + lda = static_cast(a_shape[1]); + ldb = static_cast(b_shape[1]); + ldd = static_cast(b_shape[n_idx]); +} + +template +int32_t GetTypeAndShape(const TValue* input, + TensorShape& shape, + bool swap = false) { + shape = input->Shape(); + ORT_ENFORCE(shape.NumDimensions() == 2); + if (swap) { + std::swap(shape[0], shape[1]); + } + return input->GetElementType(); +} + +Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input_A = nullptr; + const Tensor* input_B = nullptr; + const Tensor* input_C = nullptr; + const Tensor* scale_A = nullptr; + const Tensor* scale_B = nullptr; + const Tensor* scale_Y = nullptr; + bool has_scales = false; + bool has_bias = false; + int n_inputs = ctx->InputCount(); + + input_A = ctx->Input(0); + input_B = ctx->Input(1); + if (n_inputs == 3) { + input_C = ctx->Input(2); + has_bias = true; + } else if (n_inputs > 3) { + ORT_ENFORCE(n_inputs >= 5, "Unexpected number of inputs=", n_inputs, "."); + has_scales = true; + scale_A = ctx->Input(3); + scale_B = ctx->Input(4); + scale_Y = n_inputs < 6 ? nullptr : ctx->Input(5); + ORT_ENFORCE(scale_A->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_B->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_Y == nullptr || scale_Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + if (ctx->Input(2) != nullptr) { + input_C = ctx->Input(2); + has_bias = true; + ORT_ENFORCE(input_C->GetElementType() == dtype_, "Bias type must be equal to dtype."); + } + } + + auto first_type = input_A->GetElementType(); + bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; + if (!is_float8) + return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); + return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); +} + +Status GemmFloat8::ComputeRowMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_A, dtype_B, dtype_C, + dtype_Y, shape_A, shape_B, shape_C, shape_Y, transA_, transB_, + input_A->DataRaw(), input_B->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), M, N, K, lda, ldb, ldd, true); +} + +Status GemmFloat8::ComputeColMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + std::swap(shape_A[0], shape_A[1]); + std::swap(shape_B[0], shape_B[1]); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C, true) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_B, dtype_A, dtype_C, + dtype_Y, shape_B, shape_A, shape_C, shape_Y, transB_, transA_, + input_B->DataRaw(), input_A->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), N, M, K, ldb, lda, ldd, false); +} + +Status GemmFloat8::ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_B, + int32_t dtype_C, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const { + cudaStream_t stream = Stream(ctx); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + cublasLtHandle_t cublasLt; + CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublasLt)); + + cublasLtMatmulDesc_t operationDesc = nullptr; + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, + Ddesc = nullptr; + + // Create matrix descriptors. Not setting any extra attributes. + cudaDataType_t a_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_A); + cudaDataType_t b_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_B); + cudaDataType_t d_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_Y); + cudaDataType_t scale_cuda_type = + onnxruntime::cuda::ToCudaDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + cudaDataType_t bias_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_C); + + cublasComputeType_t compute_type; + switch (d_cuda_type) { + case CUDA_R_16F: + switch (a_cuda_type) { + case CUDA_R_8F_E4M3: + case CUDA_R_8F_E5M2: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + break; + } + break; + case CUDA_R_16BF: + compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; + break; + case CUDA_R_32F: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + ORT_THROW("Unable to determine computeType in operator GemmFloat8."); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Adesc, a_cuda_type, trans_A ? K : M, trans_A ? M : K, lda)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Bdesc, b_cuda_type, trans_B ? N : K, trans_B ? K : N, ldb)); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Ddesc, d_cuda_type, M, N, ldd)); + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + CUBLAS_RETURN_IF_ERROR( + cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_cuda_type)); + cublasOperation_t ctransa = trans_A ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t ctransb = trans_B ? CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &ctransa, sizeof(ctransa))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); + + if (sm_count_ != 0) { + int math_sm_count = static_cast(sm_count_); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, + sizeof(math_sm_count))); + } + + if (has_scales) { + // gemm float 8 + const int8_t ifast_accumulation_mode = 1; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, + cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &ifast_accumulation_mode, sizeof(ifast_accumulation_mode))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &p_scale_a, + sizeof(p_scale_a))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &p_scale_b, + sizeof(p_scale_b))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, + sizeof(p_scale_b))); + + // float 8 +#if CUDA_VERSION >= 11080 + if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || + dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { + // For FP8 output, cuBLAS requires C_type to be same as bias_type + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, bias_cuda_type, M, N, ldd)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_cuda_type, + sizeof(bias_cuda_type))); + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } +#else + // An output is still needed but it is not initialized. + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); +#endif + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue_, sizeof(epilogue_)); + + // See + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulPreferenceAttributes_t#cublasltmatmulpreferenceattributes-t + // The workspace should be allocated once from OpKernelContext assuming + // only one cuda function is running at a time (which is not necessarily true + // with H100). + size_t workspaceSize = static_cast(1 << 25); // suggested fixed value 32Mb + cublasLtMatmulPreference_t preference = nullptr; + cublasLtMatmulPreferenceCreate(&preference); + cublasLtMatmulPreferenceSetAttribute(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspaceSize, sizeof(workspaceSize)); + + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulAlgoGetHeuristic#cublasltmatmulalgogetheuristic + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + int returnedResults = 0; + cublasStatus_t cuda_status = cublasLtMatmulAlgoGetHeuristic( + cublasLt, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, + &heuristicResult, &returnedResults); + ORT_ENFORCE( + returnedResults > 0 && cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to find any suitable algorithm due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, + ", alpha=", alpha_, ", beta=", beta_, ", n_inputs=", n_inputs, + ", A_type=", onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", C_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", shape_C=", (shape_C.NumDimensions() > 0 ? shape_C[0] : 0), "x", + (shape_C.NumDimensions() > 1 ? shape_C[1] : 0), ", M=", M, ", N=", N, ", K=", K, + ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd, + ", workspaceSize=", workspaceSize, ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". Check NVIDIA documentation to see what combination is valid: ", + "https://docs.nvidia.com/cuda/cublas/" + "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#" + "cublasltmatmulalgogetheuristic."); + + void* workspace = nullptr; + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaMalloc(reinterpret_cast(&workspace), workspaceSize)); + } + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul + const void* bias = has_bias ? p_input_c : p_output_y; + cuda_status = cublasLtMatmul( + cublasLt, operationDesc, static_cast(&alpha_), /* alpha */ + p_input_a, /* A */ + Adesc, p_input_b, /* B */ + Bdesc, static_cast(&beta_), /* beta */ + bias, /* C */ + Cdesc, p_output_y, /* Y */ + Ddesc, &heuristicResult.algo, /* algo */ + workspace, /* workspace */ + workspaceSize, stream); /* stream */ + ORT_ENFORCE( + cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to run cublasLtMatmul due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, ", alpha=", alpha_, + ", n_inputs=", n_inputs, ", A_type=", + onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb, + ", ldd=", ldd, ", workspaceSize=", workspaceSize, + ", rowMajorCompute=", (row_major_compute ? 1 : 0), "."); + + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaFree(workspace)); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceDestroy(preference)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Ddesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Cdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Bdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Adesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescDestroy(operationDesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtDestroy(cublasLt)); + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.h b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h new file mode 100644 index 0000000000000..e84ccd55b2003 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cublas_v2.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Calls https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul. +// D = alpha*(A*B) +class GemmFloat8 final : public onnxruntime::cuda::CudaKernel { + public: + GemmFloat8(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + void SetParams(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K, + int& lda, int& ldb, int& ldd) const; + Status SetCheck(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K) const; + + Status ComputeRowMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + Status ComputeColMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + + Status ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_b, + int32_t dtype_c, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool transa, bool transb, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const; + + float alpha_; + float beta_; + bool transA_; + bool transB_; + int64_t sm_count_; + int64_t dtype_; + cublasLtEpilogue_t epilogue_; + + // TODO(xadupre): add epilogue (= activation function, Relu or Gelu are available). +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 246b66078537a..78983ac95e672 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -838,7 +838,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { Nop{}); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); if constexpr (USE_MASK) { ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh index cbf24ee2f5487..ea9040aa7875f 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -58,7 +58,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { auto zero = ToHipType::FromFloat(0.0f); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias == nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); auto nop = Nop{}; auto addfastgelu = AddFastGelu{}; @@ -67,7 +67,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { params->lda, params->ldb, std::array{0}, params->ldc, nop, nop, addfastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; @@ -95,7 +95,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias != nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); auto nop = Nop{}; auto fastgelu = FastGelu{}; @@ -108,7 +108,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { params->ldc, nop, nop, fastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc index c665da89af36c..e82e15a304f4c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc @@ -72,6 +72,12 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); } +Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* gamma = context->Input(1); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index e87813fb19956..0146e81c6cf8c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -79,7 +79,7 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { nullptr, activation); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index 3d971e6aa29a2..ef68b88187e08 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -9,6 +9,7 @@ #include "onnx/defs/data_type_utils.h" #include "core/framework/op_kernel.h" +#include "core/framework/utils.h" using namespace ONNX_NAMESPACE::Utils; @@ -77,7 +78,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe ORT_THROW_IF_ERROR(node->ForEachWithIndex( node->OutputDefs(), [&](const NodeArg& node_arg, size_t out_index) { - if (kernel_info->kernel_def->IsOutputOnCpu(out_index)) { + if (utils::IsOutputOnCpu(*node, kernel_info, out_index)) { cpu_output_args.insert(&node_arg); auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); for (auto& consumer_node : consumer_nodes) { diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 96b4cc53a022c..6d2dd641f6bc6 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -232,14 +232,15 @@ class TunableOp { return timer.Duration() / num_iter; } - static bool IsSupported(Op& op, const ParamsT* param) { - Status status = op.IsSupported(param); + // Filter all Status, only OK and TUNABLE_OP_UNSUPPORTED is left, other error status will be thrown, and to be + // processed by onnxruntime. We return Status to avoid the construction of op and params signature string. + static Status IsSupported(Op& op, const ParamsT* params) { + Status status = op.IsSupported(params); if (status.Category() == common::StatusCategory::NONE && status.Code() == common::StatusCode::INVALID_ARGUMENT) { - LOGS_DEFAULT(VERBOSE) << "unsupported reason: " << status.ErrorMessage(); - return false; + return status; } ORT_THROW_IF_ERROR(status); - return true; + return status; } protected: @@ -250,9 +251,9 @@ class TunableOp { int FindFastestImpl(const ParamsT* params, const std::vector>& candidates) { ITuningContext* ctx = params->TuningContext(); auto op_sig = Signature(); - auto param_sig = params->Signature(); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ')'; - auto min_time = std::numeric_limits::infinity(); + auto params_sig = params->Signature(); + LOGS_DEFAULT(VERBOSE) << "finding fastest for " << op_sig << '(' << params_sig << ')'; + auto min_duration_ms = std::numeric_limits::infinity(); int id = -1; constexpr const int max_tuning_iter = 100; @@ -260,30 +261,32 @@ class TunableOp { for (size_t i = 0; i < candidates.size(); i++) { auto& candidate = const_cast&>(candidates[i]); - if (!IsSupported(candidate, params)) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl found unsupported " << op_sig << '(' << param_sig << ") id=" << i; + auto status = IsSupported(candidate, params); + if (!status.IsOK()) { + LOGS_DEFAULT(VERBOSE) << "├──unsupported id=" << i << ", " << op_sig << '(' << params_sig << ")"; + LOGS_DEFAULT(VERBOSE) << "│ reason: " << status.ErrorMessage(); continue; } WarmUp(candidate, params); auto approx_duration = Profile(candidate, params, approx_num_iter); - if (approx_duration > 2 * min_time) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl skip slow instance " << op_sig << '(' << param_sig << ") id=" << i; + if (approx_duration > 2 * min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──skip slow instance id=" << i; continue; } int tuning_iter = std::max(1, int(std::min(double(max_tuning_iter), ctx->GetMaxTuningDurationMs() / approx_duration))); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl run instance " << op_sig << '(' << param_sig << ") id=" << i << " " << tuning_iter << " times."; - - auto time = Profile(candidate, params, tuning_iter); - if (time < min_time) { - min_time = time; + auto duration_ms = Profile(candidate, params, tuning_iter); + if (duration_ms < min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──found better instance, new best id=" << i << ", old id=" << id << ". " + << duration_ms << "ms, " << tuning_iter << " iters."; + min_duration_ms = duration_ms; id = static_cast(i); } } ORT_ENFORCE(id >= 0, "Could not find viable op"); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ") found fastest with id=" << id; + LOGS_DEFAULT(VERBOSE) << "└──found fastest with id=" << id << " for " << op_sig << '(' << params_sig << ")"; std::this_thread::sleep_for(std::chrono::milliseconds(50)); return id; } diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index d63881ab4ff04..23fe5e1cd3d96 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1025,7 +1025,32 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index); + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + } +#else + ORT_UNUSED_PARAMETER(node); +#endif + + return false; +} + +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) { + if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) { + return true; + } + +#ifdef ENABLE_ATEN + if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && + node.Domain() == kPytorchAtenDomain) { + const auto& attrs = node.GetAttributes(); + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); + std::string op_name = attrs.at("operator").s(); + std::string overload_name = ""; + if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) { + overload_name = attrs.at("overload_name").s(); + } + + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index ea6a629f87cb8..f0b1b9109d405 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -121,6 +121,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet bool sync_subgraph_fetches = false); bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); template constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 97befe2a58301..070df487a264d 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -191,6 +191,88 @@ void RegisterCollectiveOps() { .Output(0, "output", "Sliced data tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.") .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReshape) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr( + "allowzero", + "(Optional) By default, when any value in the 'shape' input is equal to zero " + "the corresponding dimension value is copied from the input tensor dynamically. " + "allowzero=1 indicates that if any value in the 'shape' input is set to zero, " + "the zero value is honored, similar to NumPy.", + AttributeProto::INT, + static_cast(0)) + .Input(0, "data", "An input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "Specified shape for output.", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedExpand) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 681a728f823da..e757e39130d39 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2573,6 +2573,124 @@ ONNX_MS_OPERATOR_SET_SCHEMA(CropAndResize, 1, a fixed size = [crop_height, crop_width]. The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth]. The resizing is corner aligned.)DOC")); +#if !defined(DISABLE_FLOAT8_TYPES) +#define GEMM_FLOAT8_TYPES \ + { "tensor(float8e4m3fn)", "tensor(float8e5m2)", "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#else +#define GEMM_FLOAT8_TYPES \ + { "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#endif + +ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, + OpSchema() + .SetDoc(R"DOC(Generic Gemm for float and float 8.)DOC") + .Attr( + "transA", + "Whether A should be transposed. Float 8 only supprted transA=0.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "transB", + "Whether B should be transposed. Float 8 only supprted transB=1.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "alpha", + "Scalar multiplier for the product of input tensors A * B.", + AttributeProto::FLOAT, + 1.0f) + .Attr( + "beta", + "Scalar multiplier for the product of input bias C.", + AttributeProto::FLOAT, + 0.0f) + .Attr( + "dtype", + "Output Type. Same definition as attribute 'to' for operator Cast.", + AttributeProto::INT, + static_cast(1)) + .Attr( + "activation", + "Activation function, RELU or GELU or NONE (default).", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Input( + 0, + "A", + "Input tensor A. " + "The shape of A should be (M, K) if transA is 0, " + "or (K, M) if transA is non-zero.", + "TA") + .Input( + 1, + "B", + "Input tensor B. " + "The shape of B should be (K, N) if transB is 0, " + "or (N, K) if transB is non-zero.", + "TB") + .Input( + 2, + "C", + "Input tensor C.", + "TC", + OpSchema::Optional) + .Input( + 3, + "scaleA", + "Scale of tensor A if A is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 4, + "scaleB", + "Scale of tensor B if B is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 5, + "scaleY", + "Scale of the output tensor if A or B is float 8.", + "TS", + OpSchema::Optional) + .Output(0, "Y", "Output tensor of shape (M, N).", "TR") + .TypeConstraint( + "TA", + GEMM_FLOAT8_TYPES, + "Constrain type to input A.") + .TypeConstraint( + "TB", + GEMM_FLOAT8_TYPES, + "Constrain type to input B.") + .TypeConstraint( + "TC", + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain type to input C.") + .TypeConstraint( + "TR", + GEMM_FLOAT8_TYPES, + "Constrain type to result type.") + .TypeConstraint("TS", {"tensor(float)"}, + "Constrain type for all input scales (scaleA, scaleB, scaleY).") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT); + if (!hasNInputShapes(ctx, 2)) { + return; + } + auto transAAttr = ctx.getAttribute("transA"); + bool transA = transAAttr ? static_cast(transAAttr->i()) != 0 : false; + auto transBAttr = ctx.getAttribute("transB"); + bool transB = transBAttr ? static_cast(transBAttr->i()) != 0 : false; + auto& first_input_shape = getInputShape(ctx, 0); + auto& second_input_shape = getInputShape(ctx, 1); + if (first_input_shape.dim_size() != 2) { + fail_shape_inference("First input does not have rank 2"); + } + if (second_input_shape.dim_size() != 2) { + fail_shape_inference("Second input does not have rank 2"); + } + updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)}); + })); + static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int64_t K, int64_t N) { diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index c2f5edaa6149b..f81c3b8e0182c 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -42,7 +42,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The number of groups of channels. It should be a divisor of the number of channels C", AttributeProto::INT) .Attr("activation", - "Activation after group normalization: 0 for None, 1 for Swish", + "Activation after group normalization: 0 for None, 1 for SiLU", AttributeProto::INT) .Attr("channels_last", "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", @@ -68,6 +68,85 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +constexpr const char* SkipGroupNorm_ver1_doc = R"DOC( +This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + +This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + +The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. +The num_channels must be divisible by num_groups. +The mean and standard-deviation of s are calculated separately over the each group. +The weight and bias are per-channel affine transform parameter vectors of size num_channels. + +The activation attribute can be used to enable activation after group normalization. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + SkipGroupNorm, 1, + OpSchema() + .SetDoc(SkipGroupNorm_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero", + AttributeProto::FLOAT, static_cast(1e-5)) + .Attr("groups", + "The number of groups of channels. It should be a divisor of the number of channels C", + AttributeProto::INT) + .Attr("activation", + "Activation after group normalization: 0 for None, 1 for SiLU", + AttributeProto::INT) + .Attr("channels_last", + "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", + AttributeProto::INT, + static_cast(1)) + .Input(0, + "X", + "Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 " + " or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels," + " and H and W are the height and width of the data", + "T") + .Input(1, + "gamma", + "1D gamma tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(2, + "beta", + "1D beta tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(3, + "skip", + "4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)", + "T") + .Input(4, + "bias", + "1D bias tensor. Dimensions are (C), where C is number of channels", + "T", + OpSchema::Optional) + .Output(0, + "Y", + "The output tensor of the same shape as X", + "T") + .Output(1, + "S", + "The element-wise sum of input x, skip and bias tensors. It has the same shape as X", + "T", + OpSchema::Optional) + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X, skip, bias and output Y, S types to float tensors.") + .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } + + if (hasInputShape(ctx, 0)) { + propagateShapeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + })); + constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left tensor multiplies the Gelu activation result of right tensor. diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index afaa380d6ac79..b35cfc5d12f36 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul); @@ -112,6 +113,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedSelfAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFloat8); class OpSet_Microsoft_ver1 { public: @@ -204,6 +206,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -218,6 +221,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); } }; } // namespace contrib diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 65b48a3009e72..f3bc2a2434ab3 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -229,3 +229,117 @@ MlasQ8Q4GemmBatch( const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool ); + + +//////////////////////////////////////////////////////////// +// Blockwise quantization and dequantization where quantization +// parameters are packed into separate buffers. +// + +/** + * @brief For quantization type , and + * matrix shape [rows, columns], compute the shape of the + * quantization parameter matrix [meta_rows, meta_cols] +*/ +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +/** + * @brief For quantization type , and + * matrix shape [rows, columns], compute the shape of the + * quantized matrix [q_rows, q_cols]. The quantized matrix + * is in column major layout, with bits packed on the column. + * + * @tparam T + * @param block_size + * @param columnwise + * @param rows + * @param columns + * @param q_rows + * @param q_cols +*/ +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + + +/** + * @brief Blockwise 4 bits quantization, resulting elements and quantization + * parameters (scales, zero points) are packed into separate matrices + * all in column major layout for faster access during subsequent matrix + * multiplication. + * + * @tparam ElementT type of the input matrix element, usually floating point + * + * @param dst points to the quantized matrix, shape [rows, columns] column major + * @param scales points to the scales matrix, column major + * @param zero_points points to the zero_points matrix, column major + * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param leading_dimension + * @param thread_pool +*/ +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + ElementT* scales, + uint8_t* zero_points, + const ElementT* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + +/** + * @brief Blockwise 4 bits dequantization, quantized elements and quantization + * parameters (scales, zero points) are from separate matrices packed + * in column major layout. Output is a floating point matrix in column + * major layout for faster access during subsequent matrix multiplication. + * + * @tparam ElementT type of the dequantized matrix element, usually floating point + * @param dst points to dequantized matrix shape [rows, columns] column major + * @param src points to quantized matrix, column major + * @param scales points to quantization scales, column major + * @param zero_points points to quantization zero points, column major + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param thread_pool +*/ +template +void +MlasDequantizeBlockwise( + ElementT* dst, + const uint8_t* src, + const ElementT* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 2fdb0dda5d25c..e0c2772cbb719 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1077,6 +1077,23 @@ MlasTrySimpleParallel( const std::function& Work ); + +/** + * @brief Distribute many iterations of work over a thread pool if supported. + * This function is for small workloads in non-performance critical situation. + * + * @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP + * @param Iterations [IN] Total number of iterations + * @param Work [IN] Logic for computing a range of iterations [begin, end) + */ +void +MlasTryBatchParallel( + MLAS_THREADPOOL * ThreadPool, + const std::ptrdiff_t Iterations, + const std::function& Work + ); + + inline ptrdiff_t MlasGetMaximumThreadCount( diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 85c0d13006126..24a2212ba0714 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -294,3 +294,649 @@ MlasQ4GemmUnPackB( return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); } } + + + +/*************************************************************** + * The quantization format that pack data and quantization + * parameters into separate buffers. + */ + + +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + + +template +struct BitsTraits { + static_assert(qbits <= 8, "Only BitsTraits are for small number of bits!"); + + static constexpr int kBits = qbits; + static constexpr int kMax = (1 << qbits) - 1; + static constexpr int kMid = 1 << (qbits - 1); + static constexpr float kMaxFp = static_cast(kMax); + + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); +}; + + +/** + * @brief Rectify min/max from a set of weights, and convert to scale and zero point + * for quantization + * @tparam ScaleT type of scale, usually floating point of various bits + * @tparam qbits number of int bits used for zero point value + * @param[in] min + * @param[in] max + * @param[out] scale + * @param[out] zp + */ +template +MLAS_FORCEINLINE +void +range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) +{ + constexpr int zp_max = BitsTraits::kMax; + constexpr float zp_max_fp = BitsTraits::kMaxFp; + + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + + float scale_f = (max - min) / zp_max; + + float zero_point_fp = min; + if (scale_f != 0.0f) { + zero_point_fp = 0.f - min / scale_f; + } + + if (zero_point_fp < 0.0f) { + zp = 0; + } else if (zero_point_fp > zp_max_fp) { + zp = zp_max; + } else { + zp = (uint8_t)roundf(zero_point_fp); + } + scale = static_cast(scale_f); +} + +template +MLAS_FORCEINLINE +void +range2scale(float min, float max, ScaleT& scale) +{ + constexpr int mid_v = BitsTraits::kMid; + constexpr float mid_fp = static_cast(-mid_v); + + max = fabsf(max) > fabsf(min) ? max : min; + + scale = static_cast(max / mid_fp); +}; + + +/** + * @brief Blockwise quantization methods + * @tparam ElementT source data type, e.g. fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlockwiseQuantizer { + // To support other qbits, need to add bit packing code for + // storing to dst and zero points + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; + + static + MLAS_FORCEINLINE + void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols) + { + meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + meta_cols = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn; + } + + static + MLAS_FORCEINLINE + void quantizedShape(int rows, int columns, int& q_rows, int& q_cols) { + int meta_rows; + int meta_cols; + quantizeMetaShape(rows, columns, meta_rows, meta_cols); + + // quantized matrix is stored in column major, packed by column + q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; + q_cols = meta_cols * QuantBlk::kColumn; + } + + /** + * @brief Quantized a Matrix shape [rows, columns], resulting quantized + * and packed data are stored in column major (transposed) + * @param[out] dst pointer to the quantized weights, column major: [columns, rows] + * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] + * @param[out] zero_points pointer to the zero points, same shape as scale + * @param[in] src pointer to the source matrix, row major: [rows, columns] + * @param rows + * @param columns + * @param leadingDimension stride of the source matrix, i.e. distance from one row to the next + */ + static void quantizeAndTranspose( + uint8_t* dst, + ElementT* scales, + uint8_t* zero_points, + const ElementT* src, + int32_t rows, + int32_t columns, + int32_t leadingDimension, + MLAS_THREADPOOL* thread_pool) + { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + uint8_t zp_bytes[BitsTraits::kPackSize]; + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + + const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + const int32_t r = r_blk_idx * ThreadBlk::kRow; + const int32_t c = c_blk_idx * ThreadBlk::kColumn; + + const int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + const int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + const int meta_row = r / QuantBlk::kRow; + const int meta_col = c / QuantBlk::kColumn; + + // compute scale and zero point + for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { + + // scan a single block to extract range [min, max] + float min = std::numeric_limits::max(); + float max = -min; + const int row_start = r + kpack * QuantBlk::kRow; + const int row_end = std::min(row_start + QuantBlk::kRow, r_end); + for (int i = row_start; i < row_end; ++i) { + for (int j = c; j < c_end; ++j) { + const float v = static_cast(src[i * leadingDimension + j]); + if (v < min) min = v; + if (v > max) max = v; + } + } + + // store scale and zero point at quant parameter matrix position + if (row_start < row_end) { + const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; + if (zero_points == nullptr) { + range2scale(min, max, scales[meta_idx]); + } else { + range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); + } + } + } + + // !! 4b specific code as we need to pack 2 4b numbers into one byte + if (zero_points != nullptr) { + const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; + zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + } + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_c = j / QuantBlk::kColumn; + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_r = i / QuantBlk::kRow; + const float scale = static_cast(scales[meta_c * row_blks + meta_r]); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + const int8_t zp = zp_bytes[meta_r & 1]; + const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; + + const float v0 = static_cast(src[i * leadingDimension + j]); + const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), + 0.0f, BitsTraits::kMaxFp); + + uint8_t vi1 = (uint8_t)zp; + if (i + 1 < r_end) { + float reciprocal_scale1 = reciprocal_scale; + if constexpr (QuantBlk::kRow == 1) { + const float scale1 = + static_cast(scales[meta_c * row_blks + meta_r + 1]); + reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; + } + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, + BitsTraits::kMaxFp); + } + + // !! 4b specific code + dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); + } + } + }); + } + + /** + * @brief Dequantize a column major quantized matrix, and store the result in a column major + * matrix for use in GEMM + * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] + * @param[in] weights pointer to the quantized weights, column major: [columns, rows] + * @param[in] scales pointer to the scales of quantized blocks, column major layout + * @param[in] zero_points pointer to the zero points of quantized blocks, packed column major + * scales + * @param[in] rows + * @param[in] columns + */ + static void dequantize( + ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int32_t rows, + int32_t columns, + MLAS_THREADPOOL* thread_pool) + { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; + + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_col = j / QuantBlk::kColumn; + + // !! 4b specific code + // the whole loop is 4b specific due to sub 8 bit packing + // and unpacking. We can potentially make this qbits generic + // by wraping the packing/unpacking code like cutlass::Array + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_row = i / QuantBlk::kRow; + + const float scale0 = + static_cast(scales[meta_col * row_blks + meta_row]); + + const int zp_pair = + (zero_points == nullptr) + ? 0x88 + : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; + const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); + + const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; + const float v0 = (static_cast(vi0) - zp0) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + float scale1 = scale0; + int zp1 = zp0; + if constexpr (QuantBlk::kRow == 1) { + scale1 = + static_cast(scales[meta_col * row_blks + meta_row + 1]); + zp1 = (zp_pair >> 4) & 0xf; + } + const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; + const float v1 = (static_cast(vi1) - zp1) * scale1; + dst[j * rows + (i + 1)] = static_cast(v1); + } + } + } + }); + } +}; + + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ) +{ + switch (block_size) { + case 16: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } + break; + } + case 32: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape( + rows, columns, meta_rows, meta_cols); + } + break; + } + case 64: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + case 128: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + case 256: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + default: + meta_rows = 0; + meta_cols = 0; + break; + } +} + + + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ) +{ + switch (block_size) { + case 16: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 32: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape( + rows, columns, q_rows, q_cols); + } + break; + } + case 64: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 128: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 256: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + default: + q_rows = 0; + q_cols = 0; + break; + } +} + + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + T* scales, + uint8_t* zero_points, + const T* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ) +{ + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } +} + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + +template +void +MlasDequantizeBlockwise( + T* dst, + const uint8_t* src, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ) +{ + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 32: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 64: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 128: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); + } + break; + case 256: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); + } + break; + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } +} + +template +void +MlasDequantizeBlockwise( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ); diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp index ecdc5250ebf0e..dc5daf998d3be 100644 --- a/onnxruntime/core/mlas/lib/threading.cpp +++ b/onnxruntime/core/mlas/lib/threading.cpp @@ -93,3 +93,41 @@ MlasTrySimpleParallel( MLAS_THREADPOOL::TrySimpleParallelFor(ThreadPool, Iterations, Work); #endif } + + +void +MlasTryBatchParallel( + MLAS_THREADPOOL * ThreadPool, + const std::ptrdiff_t Iterations, + const std::function& Work) +{ + // + // Execute the routine directly if only one iteration is specified. + // + if (Iterations == 1) { + Work(0); + return; + } + +#if defined(BUILD_MLAS_NO_ONNXRUNTIME) + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + // + // Fallback to OpenMP or a serialized implementation. + // + + // + // Execute the routine for the specified number of iterations. + // + for (ptrdiff_t tid = 0; tid < Iterations; tid++) { + Work(tid); + } +#else + // + // Schedule the threaded iterations using the thread pool object. + // + + MLAS_THREADPOOL::TryBatchParallelFor(ThreadPool, Iterations, Work, 0); +#endif + +} \ No newline at end of file diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 7c087ec77d9fe..959fcd6efdc3c 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -32,7 +32,7 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, int64_t to_type, onnxruntime::ProviderType providerType) { // insert cast op to cast input - std::string node_name = graph.GenerateNodeName("InsertedCast_" + old_arg->Name()); + std::string node_name = graph.GenerateNodeName("InsertedPrecisionFreeCast_" + old_arg->Name()); auto* new_arg = &graph.GetOrCreateNodeArg(node_name, new_type); @@ -235,7 +235,8 @@ enum TypeGroup { Unknown = -1, Bool = 0, Integer = 1, - Float = 2, + Unsigned = 2, + Float = 3, }; TypeGroup GetTypeGroup(DataType type) { @@ -243,11 +244,14 @@ TypeGroup GetTypeGroup(DataType type) { return Bool; } - if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || - *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)") { return Integer; } + if (*type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + return Unsigned; + } + if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { return Float; } @@ -255,6 +259,22 @@ TypeGroup GetTypeGroup(DataType type) { return Unknown; } +int BitLength(DataType type) { + if (*type == "tensor(bool)") { + return 1; + } else if (*type == "tensor(uint8)" || *type == "tensor(int8)") { + return 8; + } else if (*type == "tensor(int16)" || *type == "tensor(uint16)" || *type == "tensor(bfloat16)" || *type == "tensor(float16)") { + return 16; + } else if (*type == "tensor(int32)" || *type == "tensor(uint32)" || *type == "tensor(float)") { + return 32; + } else if (*type == "tensor(int64)" || *type == "tensor(uint64)" || *type == "tensor(double)") { + return 64; + } else { + return -1; + } +} + /** Transformer to remove duplicate Cast nodes. */ class RemoveDuplicateCastTransformer : public GraphTransformer { public: @@ -262,6 +282,48 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } private: + static bool UnsafeCast(DataType src_type, DataType dst_type, const Node& node) { + // This is not a complete cast optimisation pass, and is more conservative than it could be. + // For instance, certain integral -> floating point casts could be optimised but this is left to an explicit cast optimisation pass. + + // The comparison with "InsertedPrecisionFreeCast_" reflects cast nodes that are inserted by InsertCastTransformer. + // Such casts should not be considered as loss of precision - the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) are inserted to support kernels when on a CPU EP without F16 support. + auto src_type_group = GetTypeGroup(src_type); + auto dst_type_group = GetTypeGroup(dst_type); + if (Unknown == src_type_group || Unknown == dst_type_group) { + return true; + } + + // Do not remove any signed -> unsigned cast. + if ((src_type_group != Bool && src_type_group != Unsigned) && Unsigned == dst_type_group) { + return true; + } + + // Do not remove any floating point -> non floating point cast. + if (Float == src_type_group && Float != dst_type_group) { + return true; + } + + auto src_bit_length = BitLength(src_type); + auto dst_bit_length = BitLength(dst_type); + + // unsigned integer -> integer cast may overflow if the destination integer is smaller or equal to the source integer. + if (Unsigned == src_type_group && Integer == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + // integral -> floating cast may overflow if integer cannot be encoded in the mantissa. This check could be more precise. + if ((Integer == src_type_group || Unsigned == src_type_group) && Float == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + return true; + } + + return src_bit_length > dst_bit_length && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_")); + } + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override { auto output_args = graph.GetOutputs(); InlinedHashSet graph_outputs; @@ -293,17 +355,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { // - for each consumer cast node, it meets above condition for this optimization. auto src_type = node.InputDefs()[0]->Type(); auto dst_type = node.OutputDefs()[0]->Type(); - TypeGroup src_type_group = GetTypeGroup(src_type); - TypeGroup dst_type_group = GetTypeGroup(dst_type); - if (src_type_group == Unknown || dst_type_group == Unknown) { - continue; - } - - bool loss_precision_cast = false; - if (src_type_group > dst_type_group) { - loss_precision_cast = true; - } + bool loss_precision_cast = UnsafeCast(src_type, dst_type, node); size_t num_children = node.GetOutputEdgesCount(); bool inconsistent_casts = false; @@ -312,10 +365,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { if (output_node.OpType() == "Cast") { auto src_type1 = output_node.InputDefs()[0]->Type(); auto dst_type1 = output_node.OutputDefs()[0]->Type(); - TypeGroup src_type_group1 = GetTypeGroup(src_type1); - TypeGroup dst_type_group1 = GetTypeGroup(dst_type1); - if (src_type_group1 == Unknown || dst_type_group1 == Unknown || - (loss_precision_cast && dst_type_group1 > src_type_group1)) { + if (loss_precision_cast && UnsafeCast(dst_type1, src_type1, output_node)) { inconsistent_casts = true; break; } diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index ed3e35706b688..0d7ab70eba613 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -249,7 +249,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg if (!arg->Exists()) continue; - if (kci && kci->kernel_def->IsOutputOnCpu(i)) + if (utils::IsOutputOnCpu(node, kci, i)) non_provider_output_defs_.insert(arg); else provider_output_defs_.insert(arg); @@ -308,7 +308,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it); } if (arg_output_index != -1) { - if (!kci || !kci->kernel_def->IsOutputOnCpu(arg_output_index)) provider_output_nodes_[arg].insert(&it); + if (!kci || !utils::IsOutputOnCpu(it, kci, arg_output_index)) provider_output_nodes_[arg].insert(&it); } } } @@ -404,8 +404,8 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker // normally initializers are only inputs, but things may change with ops like assign ORT_THROW_IF_ERROR(Node::ForEachWithIndex( p_node->OutputDefs(), - [kci, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { - if (kci->kernel_def->IsOutputOnCpu(index)) { + [kci, &p_node, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { + if (utils::IsOutputOnCpu(*p_node, kci, index)) { ORT_ENFORCE(dup_replacements.find(&arg) == dup_replacements.end()); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc index 57477f167c555..288ca8e97e34d 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.cc +++ b/onnxruntime/core/providers/cuda/cuda_common.cc @@ -27,5 +27,90 @@ const HalfGemmOptions* HalfGemmOptions::GetInstance() { return &instance; } +const char* cublasGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + default: + return ""; + } +} + +const char* CudaDataTypeToString(cudaDataType_t dt) { + switch (dt) { + case CUDA_R_16F: + return "CUDA_R_16F"; + case CUDA_R_16BF: + return "CUDA_R_16BF"; + case CUDA_R_32F: + return "CUDA_R_32F"; +#if (CUDA_VERSION >= 11080) + case CUDA_R_8F_E4M3: + return "CUDA_R_8F_E4M3"; + case CUDA_R_8F_E5M2: + return "CUDA_R_8F_E5M2"; +#endif + default: + return ""; + } +} + +const char* CublasComputeTypeToString(cublasComputeType_t ct) { + switch (ct) { + case CUBLAS_COMPUTE_16F: + return "CUBLAS_COMPUTE_16F"; + case CUBLAS_COMPUTE_32F: + return "CUBLAS_COMPUTE_32F"; + case CUBLAS_COMPUTE_32F_FAST_16F: + return "CUBLAS_COMPUTE_32F_FAST_16F"; + case CUBLAS_COMPUTE_32F_FAST_16BF: + return "CUBLAS_COMPUTE_32F_FAST_16BF"; + case CUBLAS_COMPUTE_32F_FAST_TF32: + return "CUBLAS_COMPUTE_32F_FAST_TF32"; + case CUBLAS_COMPUTE_64F: + return "CUBLAS_COMPUTE_64F"; + default: + return ""; + } +} + +// It must exist somewhere already. +cudaDataType_t ToCudaDataType(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return CUDA_R_32F; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return CUDA_R_16F; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + return CUDA_R_16BF; +#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + return CUDA_R_8F_E4M3; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return CUDA_R_8F_E5M2; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index fa258961f1155..9cd4e721ccab8 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -11,6 +11,7 @@ #include "core/providers/shared_library/provider_api.h" #include "core/common/status.h" +#include "core/framework/float8.h" #include "core/framework/float16.h" #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" @@ -48,6 +49,33 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef BFloat16 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +template <> +class ToCudaType { + public: + typedef Float8E4M3FN MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +template <> +class ToCudaType { + public: + typedef Float8E5M2 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { int stride = 1; if (dims.empty() || p.size() < dims.size()) @@ -152,5 +180,13 @@ class HalfGemmOptions { static HalfGemmOptions instance; }; +const char* cublasGetErrorEnum(cublasStatus_t error); + +const char* CudaDataTypeToString(cudaDataType_t dt); + +const char* CublasComputeTypeToString(cublasComputeType_t ct); + +cudaDataType_t ToCudaDataType(int32_t element_type); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 93e18d2940fc2..4f5469ad8de36 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -645,51 +645,54 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, uint8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, float, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, double, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, int32_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, int64_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, MLFloat16, Cast); @@ -824,12 +827,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod); // opset 11 -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Flatten); @@ -843,45 +840,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 15, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, ScatterElements); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice); @@ -959,22 +917,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMax); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMin); - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, GatherND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout); @@ -1128,50 +1070,36 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize); @@ -1270,13 +1198,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub); @@ -1329,6 +1257,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, // Opset 18 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); + // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Cast); @@ -1594,51 +1528,51 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1777,12 +1711,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // opset 11 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1796,45 +1727,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1908,22 +1800,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2077,50 +1953,36 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2219,13 +2081,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2277,6 +2139,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { // Opset 18 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index f8b92eface52f..e3106e41e77c8 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -176,17 +176,17 @@ class CudaKernel : public OpKernel { return provider_->ComputeStream(); } + inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); + return gpu_data_transfer->CopyTensorAsync(src, dst, stream); + } + protected: template inline const T* GetConstOnes(size_t count, cudaStream_t stream) const { return provider_->template GetConstOnes(count, stream); } - inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); - return gpu_data_transfer->CopyTensorAsync(src, dst, stream); - } - inline int GetDeviceId() const { return provider_->GetDeviceId(); } private: diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 2f057d53d5607..d46ed9c245a8e 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -16,140 +16,29 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -// opset 11 explicitly added support for negative axis. implementation already allowed it. -#define REGISTER_KERNEL_TYPED(name, T) \ +#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 12, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register those with changes in OpSet12. -#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register ReduceMin int64_t support in OpSet14. -#define REGISTER_KERNEL_TYPED_14(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 14, \ + 1, end, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -// CUDA ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet -#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + name, \ + kOnnxDomain, \ + version, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).InputMemoryType(OrtMemTypeCPUInput, 1), \ name); -// Register with the latest version 13 -#define REGISTER_KERNEL_TYPED_13(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); +#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ + REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ + REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template @@ -917,69 +806,76 @@ template std::unique_ptr ReduceComputeShape(); + +#ifdef ENABLE_STRIDED_TENSORS + // Strided output. + if (input_data_tensor->DataRaw() == output_tensor->DataRaw()) { + gsl::span input_strides = input_data_tensor->Strides(); + TensorShapeVector output_strides = + ComputeOutputStrides(input_data_tensor->Shape(), input_strides, output_shape); + output_tensor->SetShapeAndStrides(output_shape, output_strides); + return Status::OK(); + } +#endif + + auto output_dims = output_shape.AsShapeVector(); + auto input_dims = input_data_tensor->Shape().AsShapeVector(); + + CalcEffectiveDims(input_dims, output_dims); + int rank = gsl::narrow_cast(output_dims.size()); + + TensorPitches original_input_strides(input_dims); + TensorPitches original_output_strides(output_dims); + + TArray input_strides(rank); + for (auto i = 0; i < rank; i++) { + input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i]; + } + + TArray output_strides(rank); + for (auto i = 0; i < rank; i++) { + output_strides[i] = fast_divmod(static_cast(original_output_strides[i])); + } + + return ExpandImpl( + cuda_kernel->Stream(ctx), + input_data_tensor->DataType()->Size(), + gsl::narrow_cast(output_shape.Size()), + gsl::narrow_cast(input_data_tensor->Shape().Size()), + input_data_tensor->DataRaw(), + output_tensor->MutableDataRaw(), + output_strides, + input_strides); +} + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor) { + // new shape to be expanded to + const auto* p_shape = input_shape_tensor->Data(); + TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; + TensorShape output_shape(output_dims); + + ORT_ENFORCE( + ComputeOutputShape( + cuda_kernel->Node().Name(), + input_data_tensor->Shape(), + output_dims, output_shape) + .IsOK()); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto output_tensor = Tensor::Create(input_data_tensor->DataType(), output_shape, alloc); + + // Only assign output values when output tensor is non-empty + // because empty tensor doesn't own any data. + if (output_shape.Size() > 0) { + ORT_ENFORCE(FuncExpand(cuda_kernel, ctx, input_data_tensor, input_shape_tensor, output_tensor.get()).IsOK()); + } + + return output_tensor; +} + #ifdef ENABLE_STRIDED_TENSORS #define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0) #else diff --git a/onnxruntime/core/providers/cuda/tensor/expand.h b/onnxruntime/core/providers/cuda/tensor/expand.h index 4cf4c14e61058..a0b12790017f6 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.h +++ b/onnxruntime/core/providers/cuda/tensor/expand.h @@ -20,5 +20,18 @@ Status ComputeOutputShape( const TensorShape& rhs_shape, TensorShape& out_shape); +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor); + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.cc b/onnxruntime/core/providers/cuda/tensor/reshape.cc index 3c6d900cee9a4..ab364c274a32d 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.cc +++ b/onnxruntime/core/providers/cuda/tensor/reshape.cc @@ -6,6 +6,81 @@ namespace onnxruntime { namespace cuda { +TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, // Data tensor's shape. + const gsl::span& shape_span, // Shape that data tensor reshape to. + bool allow_zero) { + TensorShapeVector shape_vector(shape_span.begin(), shape_span.end()); + ReshapeHelper helper(data_tensor_shape, shape_vector, allow_zero); + return TensorShape(shape_vector); +} + +TensorShape InferReshapeOutputShape(const Tensor* src, const Tensor* shape, bool allow_zero) { + ORT_ENFORCE(shape != nullptr, "Cannot reshape to a null shape."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "Shape must be an 1-D tensor."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape must be on CPU."); + + return InferReshapeOutputShape( + src->Shape(), + shape->template DataAsSpan(), + allow_zero); +} + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y) { + if (!X) return Status(common::ONNXRUNTIME, common::FAIL, "Missing data tensor to be reshaped."); + if (!shape) return Status(common::ONNXRUNTIME, common::FAIL, "Missing shape tensor for reshaping."); + if (shape->Shape().NumDimensions() != 1) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + } + if (shape->Location().device.Type() != OrtDevice::CPU) { + return Status(common::ONNXRUNTIME, common::FAIL, "Shape tensor must be on CPU."); + } + + const void* src_data = X->DataRaw(); + void* dst_data = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (src_data != dst_data) { + ORT_ENFORCE(ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(cuda_kernel->CopyTensor(*X, *Y, *ctx->GetComputeStream())); + } + + return Status::OK(); +} + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero) { + // TODO(wechi): Study if Tensor can be created as view to existing tensor. + // This feature can refine code for re-sharding and shape broadcasting. + + ORT_ENFORCE(X != nullptr, "Missing data tensor to be reshaped."); + ORT_ENFORCE(shape != nullptr, "Missing shape tensor for reshaping."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape tensor must be on CPU."); + + // Calculate output's shape. + auto dst_shape = InferReshapeOutputShape(X, shape, allow_zero); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto Y = Tensor::Create(X->DataType(), dst_shape, alloc); + + // Do reshape. It's equivalent to memcpy. + ORT_ENFORCE(FuncReshape(cuda_kernel, ctx, X, shape, allow_zero, Y.get()).IsOK()); + return Y; +} + ONNX_OPERATOR_KERNEL_EX( Reshape, kOnnxDomain, diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.h b/onnxruntime/core/providers/cuda/tensor/reshape.h index 01e933e65888f..8f33265071ed3 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.h +++ b/onnxruntime/core/providers/cuda/tensor/reshape.h @@ -10,6 +10,39 @@ namespace onnxruntime { namespace cuda { +// Deduce output shape from ONNX Reshape's inputs. +// +// Arguments: +// data_tensor_shape: The shape of the data tensor (i.e., 1st input). +// shape_span: Elements in the shape tensor (i.e., 2nd input). +// +// Returns: +// The output shape of this Reshape. No symbolic values such as "-1" or "0". +TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, + const gsl::span& shape_span, + bool allow_zero); + +TensorShape InferReshapeOutputShape( + const Tensor* src, + const Tensor* shape, + bool allow_zero); + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y); + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero); + class Reshape final : public CudaKernel { public: Reshape(const OpKernelInfo& info) : CudaKernel(info), @@ -18,27 +51,11 @@ class Reshape final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override { // Copy the second input tensor into the shape vector - const Tensor* shapeTensor = context->Input(1); - if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - if (shapeTensor->Shape().NumDimensions() != 1) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); - auto data_span = shapeTensor->template DataAsSpan(); - TensorShapeVector shape(data_span.begin(), data_span.end()); - const Tensor* X = context->Input(0); - if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - const TensorShape& X_shape = X->Shape(); - - ReshapeHelper helper(X_shape, shape, allow_zero_); - - Tensor* Y = context->Output(0, TensorShape(shape)); - const void* source = X->DataRaw(); - void* target = Y->MutableDataRaw(); - // If source and target pointers are not equal (non-inplace operation), we need to copy the data. - if (target != source) { - ORT_ENFORCE(context->GetComputeStream()); - ORT_RETURN_IF_ERROR(CopyTensor(*X, *Y, *context->GetComputeStream())); - } - - return Status::OK(); + const Tensor* data_tensor = context->Input(0); + const Tensor* shape_tensor = context->Input(1); + const auto target_shape = InferReshapeOutputShape(data_tensor, shape_tensor, allow_zero_); + Tensor* output_tensor = context->Output(0, target_shape); + return FuncReshape(this, context, data_tensor, shape_tensor, allow_zero_, output_tensor); } private: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index cd74e7fa92940..4f7ec188140b5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -103,6 +103,36 @@ namespace DmlGraphFusionHelper ORT_THROW_IF_FAILED(resourceUnk->QueryInterface(resource)); } + std::tuple, std::vector, std::byte*, size_t> UnpackInitializer( + const onnxruntime::Graph& graph, + const ONNX_NAMESPACE::TensorProto* initializer) + { + std::unique_ptr unpackedTensor; + std::vector unpackedExternalTensor; + std::byte* tensorPtr = nullptr; + size_t tensorByteSize = 0; + + // The tensor may be stored as raw data or in typed fields. + if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); + tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); + tensorByteSize = unpackedExternalTensor.size(); + } + else if (initializer->has_raw_data()) + { + tensorPtr = (std::byte*)(initializer->raw_data().c_str()); + tensorByteSize = initializer->raw_data().size(); + } + else + { + std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); + tensorPtr = unpackedTensor.get(); + } + + return std::make_tuple(std::move(unpackedTensor), std::move(unpackedExternalTensor), tensorPtr, tensorByteSize); + } + void ProcessInputData( const ExecutionProviderImpl* providerImpl, const std::vector& isInputsUploadedByDmlEP, @@ -161,32 +191,11 @@ namespace DmlGraphFusionHelper auto iter = initializerNameToInitializerMap.find(subGraphInputArgNames[i]); if (iter != initializerNameToInitializerMap.end()) { - std::byte* tensorPtr = nullptr; - size_t tensorByteSize = 0; - std::vector unpackedExternalTensor; - - std::unique_ptr unpackedTensor; - - //auto& initializer = iter->second; auto* initializer = iter->second.first; + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, initializer); - // The tensor may be stored as raw data or in typed fields. - if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) - { - THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); - tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); - tensorByteSize = unpackedExternalTensor.size(); - } - else if (initializer->has_raw_data()) + if (initializer->data_location() != onnx::TensorProto_DataLocation_EXTERNAL && !initializer->has_raw_data()) { - tensorPtr = (std::byte*)(initializer->raw_data().c_str()); - tensorByteSize = initializer->raw_data().size(); - } - else - { - std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); - tensorPtr = unpackedTensor.get(); - // Free the initializer if this is the last usage of it. if (initializerToLastInputIndexMap[initializer] == i) { @@ -592,9 +601,11 @@ namespace DmlGraphFusionHelper for (auto& kvp : isInitializerTransferable) { + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, kvp.second.first); + ONNX_NAMESPACE::TensorProto tensorProto; tensorProto.set_data_type(kvp.second.first->data_type()); - tensorProto.set_raw_data(kvp.second.first->raw_data()); + tensorProto.set_raw_data(tensorPtr, tensorByteSize); tensorProto.set_name(kvp.second.first->name()); for (int i = 0; i < kvp.second.first->dims_size(); ++i) diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc index c3eab9dd8e557..54528011850be 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc @@ -12,6 +12,25 @@ namespace onnxruntime { namespace ort_dnnl { +inline static dnnl::memory::format_tag get_default_format(const dnnl::memory::dims& tensor_dims) { + switch (tensor_dims.size()) { + case 1: + return dnnl::memory::format_tag::a; + case 2: + return dnnl::memory::format_tag::ab; + case 3: + return dnnl::memory::format_tag::abc; + case 4: + return dnnl::memory::format_tag::abcd; + case 5: + return dnnl::memory::format_tag::abcde; + case 6: + return dnnl::memory::format_tag::abcdef; + default: + return dnnl::memory::format_tag::undef; + } +} + DnnlMatMul::DnnlMatMul() {} // This handles ONNX defined "MatMul" as well as two other variations of MatMul @@ -139,14 +158,14 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { if (transA || transBatchA) { src_md = transposedA_md; } else { - src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any); + src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), get_default_format(src_dims)); } dnnl::memory::desc weights_md; if (transB || transBatchB) { weights_md = transposedB_md; } else { - weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any); + weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), get_default_format(weights_dims)); } auto output_shape = src_dims; @@ -241,7 +260,7 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { attr.set_scales_mask(DNNL_ARG_SRC, 0); } - auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any); + auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), get_default_format(output_shape)); auto matmul_pd = dnnl::matmul::primitive_desc(eng, src_md, weights_md, dst_md, attr); diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 78467b646b195..7e4c0dc8d7267 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -2,9 +2,7 @@ // Licensed under the MIT License #include -#include -#include -#include +#include #include "core/providers/shared_library/provider_api.h" #include "contexts.h" @@ -18,7 +16,8 @@ namespace openvino_ep { static std::unique_ptr g_global_context; GlobalContext& BackendManager::GetGlobalContext() { - // This is not thread safe to call for the first time, but it is first called on the main thread by the constructor so it is safe. + // This is not thread safe to call for the first time, + // but it is first called on the main thread by the constructor so it is safe. if (!g_global_context) g_global_context = std::make_unique(); return *g_global_context; @@ -88,7 +87,9 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node, << "Backend created for graph " << subgraph_context_.subgraph_name; } } else { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. Initializing backend for graph " << subgraph_context_.subgraph_name; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. " + << "Initializing backend for graph " + << subgraph_context_.subgraph_name; subgraph_context_.has_dynamic_input_shape = false; try { @@ -104,7 +105,7 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node, bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const { bool has_batched_inputs = true; - for (int i = 0; i < (int)subgraph_context_.input_indexes.size(); i++) { + for (int i = 0; i < static_cast(subgraph_context_.input_indexes.size()); i++) { auto& input = model_proto.graph().input(subgraph_context_.input_indexes[i]); // Batch-process only raw image inputs (NCHW or NHWC layouts) @@ -215,7 +216,10 @@ BackendManager::ReWriteInputShapeInfo(const ONNX_NAMESPACE::ModelProto& model_pr auto graph_proto = model_copy->mutable_graph(); for (size_t i = 0, limit = input_shapes.size(); i < limit; i++) { - auto g_in_shape = graph_proto->mutable_input((int)i)->mutable_type()->mutable_tensor_type()->mutable_shape(); + auto g_in_shape = graph_proto->mutable_input(static_cast(i)) + ->mutable_type() + ->mutable_tensor_type() + ->mutable_shape(); g_in_shape->clear_dim(); const auto& shape = input_shapes[i]; for (size_t dim = 0, end = shape.size(); dim < end; dim++) { @@ -234,7 +238,11 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p auto graph_proto = model_copy->mutable_graph(); for (int i = 0; i < graph_proto->input_size(); i++) { - ONNX_NAMESPACE::TensorShapeProto* g_in_shape = graph_proto->mutable_input((int)i)->mutable_type()->mutable_tensor_type()->mutable_shape(); + ONNX_NAMESPACE::TensorShapeProto* g_in_shape = + graph_proto->mutable_input(static_cast(i)) + ->mutable_type() + ->mutable_tensor_type() + ->mutable_shape(); g_in_shape->mutable_dim(0)->clear_dim_value(); g_in_shape->mutable_dim(0)->set_dim_value(1); } diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index c247ab60d3a6f..a177324b23f7d 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -3,6 +3,11 @@ #pragma once +#include +#include +#include +#include + #include "ov_interface.h" #include "contexts.h" #include "ibackend.h" @@ -13,7 +18,9 @@ namespace openvino_ep { // Singleton class that manages all the backends class BackendManager { public: - BackendManager(const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger); + BackendManager(const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, + const logging::Logger& logger); void Compute(OrtKernelContext* context); void ShutdownBackendManager(); static GlobalContext& GetGlobalContext(); @@ -21,7 +28,9 @@ class BackendManager { private: std::unique_ptr GetModelProtoFromFusedNode( - const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger) const; + const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, + const logging::Logger& logger) const; bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const; bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index d49968cdb7f3d..d47c91dd46622 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -1,9 +1,7 @@ // Copyright (C) 2019-2022 Intel Corporation // Licensed under the MIT License -#include -#include -#include +#include #include #include @@ -58,7 +56,7 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext try { auto cnn_network = global_context.ie_core.ReadModel(model); if ((subgraph_context.precision == "FP16") && - (global_context.device_type.find("VPUX") == std::string::npos)) { + (global_context.device_type.find("NPU") == std::string::npos)) { // FP16 transformations ov::pass::ConvertFP32ToFP16 pass_obj; pass_obj.run_on_model(cnn_network); @@ -88,7 +86,8 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext size_t index = results.size() - 1; for (auto it = results.rbegin(); it != results.rend(); ++it) { - if (auto const_node = std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { + if (auto const_node = + std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { const_outputs_map[(*it)->get_friendly_name()] = const_node; results.erase(results.begin() + index); } @@ -254,7 +253,7 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, void printPerformanceCounts(const std::vector& performanceMap, std::ostream& stream, std::string deviceName) { - long long totalTime = 0; + int64_t totalTime = 0; // Print performance counts stream << std::endl << "performance counts:" << std::endl diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index de78a150fe2dd..82b0351e87da5 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -4,9 +4,15 @@ #pragma once #define ORT_API_MANUAL_INIT +#include +#include +#include +#include +#include +#include + #include "core/session/onnxruntime_cxx_api.h" #include "contexts.h" -#include #include "ov_interface.h" #ifdef _WIN32 #include @@ -57,7 +63,9 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, size_t batch_slice_idx); std::shared_ptr -CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, const SubGraphContext& subgraph_context, +CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, + const GlobalContext& global_context, + const SubGraphContext& subgraph_context, std::map>& const_outputs_map); void printPerformanceCounts(const std::vector& performanceMap, diff --git a/onnxruntime/core/providers/openvino/backends/backend_factory.cc b/onnxruntime/core/providers/openvino/backends/backend_factory.cc index c339f24e7022f..c586dd8b38af9 100644 --- a/onnxruntime/core/providers/openvino/backends/backend_factory.cc +++ b/onnxruntime/core/providers/openvino/backends/backend_factory.cc @@ -16,7 +16,7 @@ BackendFactory::MakeBackend(const ONNX_NAMESPACE::ModelProto& model_proto, const SubGraphContext& subgraph_context) { std::string type = global_context.device_type; if (type == "CPU" || type.find("GPU") != std::string::npos || - type.find("VPUX") != std::string::npos || + type.find("NPU") != std::string::npos || type.find("HETERO") != std::string::npos || type.find("MULTI") != std::string::npos || type.find("AUTO") != std::string::npos) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index f9517d7942664..09e1322ff59fb 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -6,10 +6,10 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "../backend_utils.h" -// #include #include "basic_backend.h" #include "../backend_manager.h" @@ -57,33 +57,39 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, cl_context ctx = static_cast(global_context_.context); remote_context_ = new ov::intel_gpu::ocl::ClContext(global_context_.ie_core.Get(), ctx); ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) if (!subgraph_context_.has_dynamic_input_shape && dev_prec != "CPU_FP16") { const std::string model = model_proto.SerializeAsString(); - exe_network_ = global_context_.ie_core.LoadNetwork(model, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + model, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; #endif #endif } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } } catch (const char* msg) { @@ -127,10 +133,10 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { } #endif #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) - if (global_context_.device_type.find("VPUX") != std::string::npos) { + if (global_context_.device_type.find("NPU") != std::string::npos) { std::pair device_property; - device_property = std::make_pair("VPU_COMPILER_TYPE", "MLIR"); - device_config.emplace(ov::device::properties("VPUX", device_property)); + device_property = std::make_pair("NPU_COMPILER_TYPE", "DRIVER"); + device_config.emplace(ov::device::properties("NPU", device_property)); } #endif } @@ -152,12 +158,12 @@ void BasicBackend::EnableCaching() { } void BasicBackend::EnableGPUThrottling(ov::AnyMap& device_config) { - if (global_context_.enable_opencl_throttling == true && global_context_.device_type.find("GPU") != std::string::npos) { + if (global_context_.enable_opencl_throttling == true && + global_context_.device_type.find("GPU") != std::string::npos) { LOGS_DEFAULT(INFO) << log_tag << "Enabled OpenCL queue throttling for GPU device"; std::pair device_property; device_property = std::make_pair("PLUGIN_THROTTLE", "1"); device_config.emplace(ov::device::properties("GPU_CONFIG_KEY", device_property)); - // device_config[GPU_CONFIG_KEY(PLUGIN_THROTTLE)] = "1"; } } @@ -187,7 +193,9 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque if (input_names.find(onnx_input_name) != input_names.end()) { input_name = onnx_input_name; } else { - throw(log_tag + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + " doesn't exist in the list of OpenVINO input tensor names"); + throw(log_tag + + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); } size_t batch_slice_idx = 0; if (subgraph_context_.has_dynamic_input_shape && @@ -197,6 +205,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); auto tensor_shape = tensor_info.GetShape(); auto tensor_size = tensor_shape.size(); + const char* tensor_data = tensor.GetTensorData(); auto tensor_iter = 0; ov::Shape input_tensor_shape = ov::Shape(tensor_size, 0); for (auto i = tensor_shape.begin(); i != tensor_shape.end(); ++i) { @@ -204,8 +213,16 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque tensor_iter += 1; } auto input = ie_cnn_network_->get_parameters().at(input_idx); - OVTensorPtr tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape); - FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + OVTensorPtr tensor_ptr; + // avoid input copies on the CPU device + if (global_context_.device_type.find("CPU") != std::string::npos) { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape, + (void*)tensor_data); + } else { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape); + FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + } + try { infer_request->SetTensor(input_name, tensor_ptr); } catch (const char* msg) { @@ -251,7 +268,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe if (input_names.find(onnx_input_name) != input_names.end()) { input_name = onnx_input_name; } else { - throw(log_tag + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + " doesn't exist in the list of OpenVINO input tensor names"); + throw(log_tag + + "Input names mismatch between OpenVINO and ONNX. " + + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); } input_idx++; // Kernel Context Input Buffer @@ -264,9 +284,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe const cl::Buffer* shared_buffer_const = static_cast(tensor_data); // Create an Input Remote Blob auto input = ie_cnn_network_->get_parameters().at(0); - auto remote_blob = remote_context_->create_tensor(input->get_element_type(), input->get_shape(), *shared_buffer_const); - ov::Tensor tensor = static_cast(remote_blob); - OVTensorPtr tensor_ptr = std::make_shared(tensor); + auto remote_blob = remote_context_->create_tensor( + input->get_element_type(), input->get_shape(), *shared_buffer_const); + ov::Tensor tensor_remote = static_cast(remote_blob); + OVTensorPtr tensor_ptr = std::make_shared(tensor_remote); infer_request->SetTensor(input_name, tensor_ptr); } else { OVTensorPtr graph_input_blob; @@ -295,7 +316,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe } } if (!output_name_found) { - throw std::string(log_tag + "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); + throw std::string( + log_tag + + "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + + onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); } size_t batch_size = 1; @@ -307,9 +331,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe const cl::Buffer* shared_buffer_const = static_cast(tensor_data); // Create a shared Blob, set the Infer Request Output Blob auto output = ie_cnn_network_->get_results().at(0); - auto remote_tensor = remote_context_->create_tensor(output->get_element_type(), output->get_shape(), *shared_buffer_const); - ov::Tensor tensor = static_cast(remote_tensor); - OVTensorPtr tensor_ptr = std::make_shared(tensor); + auto remote_tensor = + remote_context_->create_tensor(output->get_element_type(), output->get_shape(), *shared_buffer_const); + ov::Tensor tensor_t = static_cast(remote_tensor); + OVTensorPtr tensor_ptr = std::make_shared(tensor_t); try { infer_request->SetTensor(output_name, tensor_ptr); } catch (const char* msg) { @@ -364,7 +389,8 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe throw(msg); } size_t batch_size = 1; - auto output_tensor = GetOutputTensor(context, batch_size, infer_request, output_name, subgraph_context_.output_names); + auto output_tensor = + GetOutputTensor(context, batch_size, infer_request, output_name, subgraph_context_.output_names); auto mem_info = output_tensor.GetTensorMemoryInfo(); if (mem_info.GetAllocatorName() == OpenVINO_GPU) { return; @@ -465,7 +491,8 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { #ifndef IO_BUFFER_ENABLED // Printing performance counts is disabled when IO_BUFFER_ENABLED if (openvino_ep::backend_utils::IsDebugEnabled()) { inferRequestsQueue_->printstatus(); // Printing the elements of infer_requests_ vector pool only in debug mode - std::string& hw_target = (global_context_.device_id != "") ? global_context_.device_id : global_context_.device_type; + std::string& hw_target = + (global_context_.device_id != "") ? global_context_.device_id : global_context_.device_type; printPerformanceCounts(infer_request, std::cout, hw_target); } #endif diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 2f1d603640809..6eda641451a72 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -6,16 +6,17 @@ #include #define ORT_API_MANUAL_INIT -#include "core/session/onnxruntime_cxx_api.h" -#include "core/providers/openvino/contexts.h" -#include "core/providers/openvino/ibackend.h" -#include "core/providers/openvino/ov_interface.h" #include #include #include #include #include +#include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/openvino/contexts.h" +#include "core/providers/openvino/ibackend.h" +#include "core/providers/openvino/ov_interface.h" + namespace onnxruntime { namespace openvino_ep { @@ -29,7 +30,7 @@ class BasicBackend : public IBackend { void Infer(OrtKernelContext* context) override; private: - bool ImportBlob(std::string hw_target, bool vpu_status); + bool ImportBlob(std::string hw_target, bool npu_status); void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&); bool ValidateSubgraph(std::map>& const_outputs_map); void PopulateConfigValue(ov::AnyMap& device_config); diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index b61dcf8ca4922..29233e72c33b9 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -3,6 +3,9 @@ #pragma once +#include +#include +#include #include "ov_interface.h" namespace onnxruntime { @@ -12,7 +15,7 @@ namespace openvino_ep { struct GlobalContext { OVCore ie_core; bool is_wholly_supported_graph = false; - bool enable_vpu_fast_compile = false; + bool enable_npu_fast_compile = false; bool enable_opencl_throttling = false; bool enable_dynamic_shapes = false; size_t num_of_threads; @@ -34,7 +37,7 @@ struct GlobalContext { struct SubGraphContext { bool has_dynamic_input_shape = false; bool enable_batching = false; - bool set_vpu_config = false; + bool set_npu_config = false; bool is_constant = false; void* context = 0; std::string subgraph_name; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 990809926299e..a4c6b0f851c04 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -17,17 +17,18 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv openvino_ep::BackendManager::GetGlobalContext().device_type = info.device_type_; openvino_ep::BackendManager::GetGlobalContext().precision_str = info.precision_; - openvino_ep::BackendManager::GetGlobalContext().enable_vpu_fast_compile = info.enable_vpu_fast_compile_; + openvino_ep::BackendManager::GetGlobalContext().enable_npu_fast_compile = info.enable_npu_fast_compile_; openvino_ep::BackendManager::GetGlobalContext().cache_dir = info.cache_dir_; openvino_ep::BackendManager::GetGlobalContext().num_streams = info.num_streams_; openvino_ep::BackendManager::GetGlobalContext().context = info.context_; openvino_ep::BackendManager::GetGlobalContext().enable_opencl_throttling = info.enable_opencl_throttling_; openvino_ep::BackendManager::GetGlobalContext().enable_dynamic_shapes = info.enable_dynamic_shapes_; - if ((int)info.num_of_threads_ <= 0) { + if (static_cast(info.num_of_threads_) <= 0) { openvino_ep::BackendManager::GetGlobalContext().num_of_threads = 8; - } else if ((int)info.num_of_threads_ > 8) { - std::string err_msg = std::string("\n [ERROR] num_of_threads configured during runtime is: ") + std::to_string(info.num_of_threads_) + "\nnum_of_threads configured should be >0 and <=8.\n"; + } else if (static_cast(info.num_of_threads_) > 8) { + std::string err_msg = std::string("\n [ERROR] num_of_threads configured during runtime is: ") + + std::to_string(info.num_of_threads_) + "\nnum_of_threads configured should be >0 and <=8.\n"; ORT_THROW(err_msg); } else { openvino_ep::BackendManager::GetGlobalContext().num_of_threads = info.num_of_threads_; @@ -56,7 +57,8 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv device_found = true; break; } - if (info.device_type_.find("VPUX") != std::string::npos && (info.precision_ == "FP16" || info.precision_ == "U8")) { + if ((info.device_type_.find("NPU") != std::string::npos) && + (info.precision_ == "FP16" || info.precision_ == "U8")) { device_found = true; break; } @@ -109,11 +111,14 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, openvino_ep::BackendManager::GetGlobalContext().onnx_model_name = graph_viewer.Name(); #ifdef _WIN32 std::wstring onnx_path = graph_viewer.ModelPath().ToPathString(); - openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = std::string(onnx_path.begin(), onnx_path.end()); + openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = + std::string(onnx_path.begin(), onnx_path.end()); #else - openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = graph_viewer.ModelPath().ToPathString(); + openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = + graph_viewer.ModelPath().ToPathString(); #endif - openvino_ep::BackendManager::GetGlobalContext().onnx_opset_version = graph_viewer.DomainToVersionMap().at(kOnnxDomain); + openvino_ep::BackendManager::GetGlobalContext().onnx_opset_version = + graph_viewer.DomainToVersionMap().at(kOnnxDomain); #if defined(OPENVINO_2022_1) openvino_ep::GetCapability obj(graph_viewer, @@ -151,7 +156,8 @@ common::Status OpenVINOExecutionProvider::Compile( openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true; - std::shared_ptr backend_manager = std::make_shared(fused_node, graph_body_viewer, *GetLogger()); + std::shared_ptr backend_manager = + std::make_shared(fused_node, graph_body_viewer, *GetLogger()); compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index a4fc09362fa23..3b56b54410e40 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -3,19 +3,28 @@ #pragma once -#include "backend_manager.h" #include #include #include +#include +#include +#include + +#include "backend_manager.h" namespace onnxruntime { static void print_build_options() { std::cout << "[ERROR] INVALID DEVICE BUILD TYPE SPECIFIED" << std::endl; - std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority you want to build" << std::endl; - std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build "; - std::cout << "are ['CPU','GPU','VPUX']" << std::endl; - std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" << std::endl; + std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority " + << "you want to build" + << std::endl; + std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build " + << "are ['CPU','GPU']" + << std::endl; + std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. " + << "Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" + << std::endl; } static std::vector split(const std::string& s, char delim) { @@ -39,7 +48,7 @@ static std::vector parseDevices(const std::string& device_string) { print_build_options(); ORT_THROW("Invalid device string: " + device_string); } - std::vector dev_options = {"CPU", "GPU", "VPUX"}; + std::vector dev_options = {"CPU", "GPU"}; for (std::string dev : devices) { if (!std::count(dev_options.begin(), dev_options.end(), dev)) { print_build_options(); @@ -53,7 +62,7 @@ static std::vector parseDevices(const std::string& device_string) { struct OpenVINOExecutionProviderInfo { std::string device_type_; std::string precision_; - bool enable_vpu_fast_compile_; + bool enable_npu_fast_compile_; std::string device_id_; size_t num_of_threads_; std::string cache_dir_; @@ -62,11 +71,18 @@ struct OpenVINOExecutionProviderInfo { bool enable_opencl_throttling_; bool enable_dynamic_shapes_; - explicit OpenVINOExecutionProviderInfo(std::string dev_type, bool enable_vpu_fast_compile, std::string dev_id, + explicit OpenVINOExecutionProviderInfo(std::string dev_type, bool enable_npu_fast_compile, std::string dev_id, size_t num_of_threads, std::string cache_dir, int num_streams, void* context, bool enable_opencl_throttling, bool enable_dynamic_shapes) - : enable_vpu_fast_compile_(enable_vpu_fast_compile), device_id_(dev_id), num_of_threads_(num_of_threads), cache_dir_(cache_dir), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), enable_dynamic_shapes_(enable_dynamic_shapes) { + : enable_npu_fast_compile_(enable_npu_fast_compile), + device_id_(dev_id), + num_of_threads_(num_of_threads), + cache_dir_(cache_dir), + num_streams_(num_streams), + context_(context), + enable_opencl_throttling_(enable_opencl_throttling), + enable_dynamic_shapes_(enable_dynamic_shapes) { if (dev_type == "") { LOGS_DEFAULT(INFO) << "[OpenVINO-EP]" << "No runtime device selection option provided."; @@ -82,11 +98,11 @@ struct OpenVINOExecutionProviderInfo { #elif defined OPENVINO_CONFIG_GPU_FP16 device_type_ = "GPU"; precision_ = "FP16"; -#elif defined OPENVINO_CONFIG_VPUX_FP16 - device_type_ = "VPUX"; +#elif defined OPENVINO_CONFIG_NPU_FP16 + device_type_ = "NPU"; precision_ = "FP16"; -#elif defined OPENVINO_CONFIG_VPUX_U8 - device_type_ = "VPUX"; +#elif defined OPENVINO_CONFIG_NPU_U8 + device_type_ = "NPU"; precision_ = "U8"; #elif defined OPENVINO_CONFIG_HETERO || defined OPENVINO_CONFIG_MULTI || defined OPENVINO_CONFIG_AUTO #ifdef DEVICE_NAME @@ -126,11 +142,11 @@ struct OpenVINOExecutionProviderInfo { } else if (dev_type == "GPU.1_FP16") { device_type_ = "GPU.1"; precision_ = "FP16"; - } else if (dev_type == "VPUX_FP16") { - device_type_ = "VPUX"; + } else if (dev_type == "NPU_FP16") { + device_type_ = "NPU"; precision_ = "FP16"; - } else if (dev_type == "VPUX_U8") { - device_type_ = "VPUX"; + } else if (dev_type == "NPU_U8") { + device_type_ = "NPU"; precision_ = "U8"; } else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0) { std::vector devices = parseDevices(dev_type); diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 95b39bcc05983..fbb89710c8008 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -8,11 +8,16 @@ namespace onnxruntime { struct OpenVINOProviderFactory : IExecutionProviderFactory { - OpenVINOProviderFactory(const char* device_type, bool enable_vpu_fast_compile, + OpenVINOProviderFactory(const char* device_type, bool enable_npu_fast_compile, const char* device_id, size_t num_of_threads, const char* cache_dir, int num_streams, void* context, bool enable_opencl_throttling, bool enable_dynamic_shapes) - : enable_vpu_fast_compile_(enable_vpu_fast_compile), num_of_threads_(num_of_threads), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), enable_dynamic_shapes_(enable_dynamic_shapes) { + : enable_npu_fast_compile_(enable_npu_fast_compile), + num_of_threads_(num_of_threads), + num_streams_(num_streams), + context_(context), + enable_opencl_throttling_(enable_opencl_throttling), + enable_dynamic_shapes_(enable_dynamic_shapes) { device_type_ = (device_type == nullptr) ? "" : device_type; device_id_ = (device_id == nullptr) ? "" : device_id; cache_dir_ = (cache_dir == nullptr) ? "" : cache_dir; @@ -24,7 +29,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { private: std::string device_type_; - bool enable_vpu_fast_compile_; + bool enable_npu_fast_compile_; std::string device_id_; size_t num_of_threads_; std::string cache_dir_; @@ -35,7 +40,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { }; std::unique_ptr OpenVINOProviderFactory::CreateProvider() { - OpenVINOExecutionProviderInfo info(device_type_, enable_vpu_fast_compile_, device_id_, num_of_threads_, + OpenVINOExecutionProviderInfo info(device_type_, enable_npu_fast_compile_, device_id_, num_of_threads_, cache_dir_, num_streams_, context_, enable_opencl_throttling_, enable_dynamic_shapes_); return std::make_unique(info); @@ -59,17 +64,18 @@ struct OpenVINO_Provider : Provider { std::string device_type = ""; // [device_type]: Overrides the accelerator hardware type and precision // with these values at runtime. - bool enable_vpu_fast_compile = false; // [enable_vpu_fast_compile]: Fast-compile may be optionally enabled to - // speeds up the model's compilation to VPU device specific format. + bool enable_npu_fast_compile = false; // [enable_npu_fast_compile]: Fast-compile may be optionally enabled to + // speeds up the model's compilation to NPU device specific format. const char* device_id = ""; // [device_id]: Selects a particular hardware device for inference. - size_t num_of_threads = 8; // [num_of_threads]: Overrides the accelerator default value of number of + int num_of_threads = 8; // [num_of_threads]: Overrides the accelerator default value of number of // threads with this value at runtime. const char* cache_dir = ""; // [cache_dir]: specify the path to // dump and load the blobs for the model caching/kernel caching (GPU) // feature. If blob files are already present, it will be directly loaded. int num_streams = 1; // [num_streams]: Option that specifies the number of parallel inference // requests to be processed on a given `device_type`. Overrides the - // accelerator default value of number of streams with this value at runtime. + // accelerator default value of number of streams + // with this value at runtime. bool enable_opencl_throttling = false; // [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU // device (Reduces CPU Utilization when using GPU) bool enable_dynamic_shapes = false; // [enable_dynamic_shapes]: Enables Dynamic Shapes feature for CPU device) @@ -80,14 +86,15 @@ struct OpenVINO_Provider : Provider { std::set ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16", - "VPUX_FP16", "VPUX_U8"}; + "GPU.0_FP16", "GPU.1_FP16"}; if (!((ov_supported_device_types.find(device_type) != ov_supported_device_types.end()) || - (device_type.find("HETERO:") == 0) || (device_type.find("MULTI:") == 0) || (device_type.find("AUTO:") == 0))) { + (device_type.find("HETERO:") == 0) || + (device_type.find("MULTI:") == 0) || + (device_type.find("AUTO:") == 0))) { ORT_THROW( "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " - "'GPU.0_FP16', 'GPU.1_FP16', 'VPUX_FP16', 'VPUX_U8' or from" + "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); } } @@ -97,30 +104,37 @@ struct OpenVINO_Provider : Provider { if (provider_options_map.find("cache_dir") != provider_options_map.end()) { cache_dir = provider_options_map.at("cache_dir").c_str(); } + if (provider_options_map.find("context") != provider_options_map.end()) { - context = (void*)provider_options_map.at("context").c_str(); + std::string str = provider_options_map.at("context"); + uint64_t number = std::strtoull(str.c_str(), nullptr, 16); + context = reinterpret_cast(number); } if (provider_options_map.find("num_of_threads") != provider_options_map.end()) { num_of_threads = std::stoi(provider_options_map.at("num_of_threads")); if (num_of_threads <= 0) { num_of_threads = 1; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] The value for the key 'num_threads' should be in the positive range.\n " + << "Executing with num_threads=1"; } } if (provider_options_map.find("num_streams") != provider_options_map.end()) { num_streams = std::stoi(provider_options_map.at("num_streams")); - if (num_streams <= 0 && num_streams > 8) { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'num_streams' should be in the range of 1-8 \n"); + if (num_streams <= 0) { + num_streams = 1; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] The value for the key 'num_streams' should be in the range of 1-8.\n " + << "Executing with num_streams=1"; } } std::string bool_flag = ""; - if (provider_options_map.find("enable_vpu_fast_compile") != provider_options_map.end()) { - bool_flag = provider_options_map.at("enable_vpu_fast_compile"); + if (provider_options_map.find("enable_npu_fast_compile") != provider_options_map.end()) { + bool_flag = provider_options_map.at("enable_npu_fast_compile"); if (bool_flag == "true" || bool_flag == "True") - enable_vpu_fast_compile = true; + enable_npu_fast_compile = true; else if (bool_flag == "false" || bool_flag == "False") - enable_vpu_fast_compile = false; + enable_npu_fast_compile = false; bool_flag = ""; } @@ -141,7 +155,7 @@ struct OpenVINO_Provider : Provider { enable_dynamic_shapes = false; } return std::make_shared(const_cast(device_type.c_str()), - enable_vpu_fast_compile, + enable_npu_fast_compile, device_id, num_of_threads, cache_dir, @@ -157,7 +171,6 @@ struct OpenVINO_Provider : Provider { void Shutdown() override { openvino_ep::BackendManager::ReleaseGlobalContext(); } - } g_provider; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 3914488fc523b..d2ce378c97e02 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -29,7 +29,10 @@ std::shared_ptr OVCore::ReadModel(const std::string& model) const { } } -OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name) { +OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name) { ov::CompiledModel obj; try { obj = oe.compile_model(ie_cnn_network, hw_target, device_config); @@ -43,7 +46,10 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, std } #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) -OVExeNetwork OVCore::LoadNetwork(const std::string& model, std::string& hw_target, ov::AnyMap& device_config, std::string name) { +OVExeNetwork OVCore::LoadNetwork(const std::string& model, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name) { ov::CompiledModel obj; try { obj = oe.compile_model(model, ov::Tensor(), hw_target, device_config); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index ed9583033ab34..935ac8f68411d 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -4,6 +4,7 @@ #pragma once #include +#include #if defined(OPENVINO_2022_1) || (OPENVINO_2022_2) || (OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1) #define OV_API_20 @@ -43,9 +44,15 @@ class OVCore { public: std::shared_ptr ReadModel(const std::string& model_stream) const; - OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name); + OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name); #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) - OVExeNetwork LoadNetwork(const std::string& model_stream, std::string& hw_target, ov::AnyMap& device_config, std::string name); + OVExeNetwork LoadNetwork(const std::string& model_stream, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name); #endif void SetCache(std::string cache_dir_path); #ifdef IO_BUFFER_ENABLED @@ -62,7 +69,7 @@ class OVExeNetwork { ov::CompiledModel obj; public: - OVExeNetwork(ov::CompiledModel md) { obj = md; } + explicit OVExeNetwork(ov::CompiledModel md) { obj = md; } OVExeNetwork() { obj = ov::CompiledModel(); } ov::CompiledModel& Get() { return obj; } OVInferRequest CreateInferRequest(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/capabilities.h b/onnxruntime/core/providers/openvino/ov_versions/capabilities.h index b76d1cf534c2a..5bcf9d68cd94e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capabilities.h +++ b/onnxruntime/core/providers/openvino/ov_versions/capabilities.h @@ -3,6 +3,8 @@ #pragma once #include +#include +#include #include "data_ops.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 171dd45c508cc..b030efa238209 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -24,7 +24,8 @@ namespace openvino_ep { // Constructor GetCapability::GetCapability(const GraphViewer& graph_viewer_param, std::string device_type_param, - const std::string version_param) : graph_viewer_(graph_viewer_param), device_type_(device_type_param) { + const std::string version_param) + : graph_viewer_(graph_viewer_param), device_type_(device_type_param) { if (version_param == "V_2022_1") { data_ops_ = new DataOps(graph_viewer_, V_2022_1, device_type_); } else if (version_param == "V_2022_2") { @@ -114,11 +115,11 @@ std::vector> GetCapability::Execute() { } openvino_ep::BackendManager::GetGlobalContext().is_wholly_supported_graph = true; - } else { // unsupported_nodes_idx.empty() - + } else { // unsupported_nodes_idx.empty() #if defined(OPENVINO_DISABLE_GRAPH_PARTITION) // disables graph partition at build time LOGS_DEFAULT(INFO) << "[OpenVINO-EP] DISABLE_GRAPH_PARTITION option is set"; - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, so making the full model fall back to default CPU Execution Provider"; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, " + << "so making the full model fall back to default CPU Execution Provider"; return result; #endif @@ -159,7 +160,13 @@ std::vector> GetCapability::Execute() { std::vector cluster_graph_inputs, cluster_inputs, const_inputs, cluster_outputs; - GetInputsOutputsOfCluster(graph_viewer_, this_cluster, ng_required_initializers, cluster_graph_inputs, cluster_inputs, const_inputs, cluster_outputs); + GetInputsOutputsOfCluster(graph_viewer_, + this_cluster, + ng_required_initializers, + cluster_graph_inputs, + cluster_inputs, + const_inputs, + cluster_outputs); bool omit_subgraph = false; // Omitting zero dim subgraphs diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 70118c94f9ff8..a5a0faa3a8f24 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -2,11 +2,15 @@ // Licensed under the MIT License #include +#include +#include +#include +#include +#include + #include "core/providers/shared_library/provider_api.h" #include "../backend_utils.h" #include "../backend_manager.h" -#include -#include #include "data_ops.h" #include "capabilities.h" #include "utils.h" @@ -72,269 +76,355 @@ std::set ops_supported_as_function = { std::vector supported_op_mode = { {"Abs", V_2020_4, {"CPU", "GPU"}}, - {"Abs", V_2023_0, {"VPUX"}}, + {"Abs", V_2023_0, {"NPU"}}, {"Acos", V_2020_4, {"CPU"}}, {"Acos", V_2022_1, {"GPU"}}, + {"Acos", V_2023_1, {"NPU"}}, {"Acosh", V_2020_4, {"CPU"}}, {"Acosh", V_2022_1, {"GPU"}}, + {"Acosh", V_2023_1, {"NPU"}}, {"Add", V_2020_4, {"CPU", "GPU"}}, - {"Add", V_2023_0, {"VPUX"}}, + {"Add", V_2023_0, {"NPU"}}, {"And", V_2020_4, {"CPU", "GPU"}}, + {"And", V_2023_1, {"NPU"}}, {"ArgMax", V_2020_4, {"CPU"}}, {"ArgMax", V_2021_1, {"GPU"}}, {"ArgMin", V_2020_4, {"CPU"}}, {"ArgMin", V_2022_1, {"GPU"}}, {"Asin", V_2020_4, {"CPU", "GPU"}}, + {"Asin", V_2023_1, {"NPU"}}, {"Asinh", V_2020_4, {"CPU", "GPU"}}, + {"Asinh", V_2023_1, {"NPU"}}, {"Atan", V_2020_4, {"CPU", "GPU"}}, + {"Atan", V_2023_1, {"NPU"}}, {"Atanh", V_2020_4, {"CPU"}}, {"Atanh", V_2022_1, {"GPU"}}, + {"Atanh", V_2023_1, {"NPU"}}, {"AveragePool", V_2020_4, {"CPU", "GPU"}}, - {"AveragePool", V_2023_0, {"VPUX"}}, + {"AveragePool", V_2023_0, {"NPU"}}, {"BatchNormalization", V_2020_4, {"CPU", "GPU"}}, - {"BatchNormalization", V_2023_0, {"VPUX"}}, + {"BatchNormalization", V_2023_0, {"NPU"}}, {"BitShift", V_2022_1, {"CPU"}}, + {"BitShift", V_2023_1, {"NPU"}}, {"Cast", V_2020_4, {"CPU", "GPU"}}, - {"Cast", V_2023_0, {"VPUX"}}, + {"Cast", V_2023_0, {"NPU"}}, + {"CastLike", V_2023_1, {"CPU", "GPU", "NPU"}}, {"Ceil", V_2020_4, {"GPU"}}, {"Ceil", V_2021_4, {"CPU"}}, + {"Ceil", V_2023_1, {"NPU"}}, {"Celu", V_2022_1, {"CPU", "GPU"}}, {"Clip", V_2020_4, {"CPU", "GPU"}}, - {"Clip", V_2023_0, {"VPUX"}}, + {"Clip", V_2023_0, {"NPU"}}, + {"Compress", V_2023_1, {"CPU", "GPU"}}, {"Concat", V_2020_4, {"CPU", "GPU"}}, - {"Concat", V_2023_0, {"VPUX"}}, + {"Concat", V_2023_0, {"NPU"}}, {"Constant", V_2020_4, {"CPU", "GPU"}}, - {"Constant", V_2023_0, {"VPUX"}}, + {"Constant", V_2023_0, {"NPU"}}, {"ConstantOfShape", V_2020_4, {"CPU", "GPU"}}, - {"ConstantOfShape", V_2023_0, {"VPUX"}}, // Gets mapped to broadcast op in the plugin. + {"ConstantOfShape", V_2023_0, {"NPU"}}, // Gets mapped to broadcast op in the plugin. {"Conv", V_2020_4, {"CPU", "GPU"}}, - {"Conv", V_2023_0, {"VPUX"}}, + {"Conv", V_2023_0, {"NPU"}}, {"ConvInteger", V_2022_1, {"CPU", "GPU"}}, + {"ConvInteger", V_2023_1, {"NPU"}}, {"ConvTranspose", V_2020_4, {"CPU", "GPU"}}, + {"ConvTranspose", V_2023_1, {"NPU"}}, {"Cos", V_2020_4, {"CPU"}}, {"Cos", V_2022_1, {"GPU"}}, - {"Cos", V_2023_0, {"VPUX"}}, + {"Cos", V_2023_0, {"NPU"}}, {"Cosh", V_2020_4, {"CPU"}}, {"Cosh", V_2022_1, {"GPU"}}, + {"Cosh", V_2023_1, {"NPU"}}, {"CumSum", V_2022_1, {"CPU", "GPU"}}, - {"CumSum", V_2023_0, {"VPUX"}}, + {"CumSum", V_2023_0, {"NPU"}}, {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, - {"DepthToSpace", V_2023_0, {"VPUX"}}, + {"DepthToSpace", V_2023_0, {"NPU"}}, {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, - {"DequantizeLinear", V_2023_0, {"VPUX"}}, + {"DequantizeLinear", V_2023_0, {"NPU"}}, {"Div", V_2020_4, {"CPU", "GPU"}}, - {"Div", V_2023_0, {"VPUX"}}, + {"Div", V_2023_0, {"NPU"}}, {"Dropout", V_2020_4, {"CPU", "GPU"}}, - {"Dropout", V_2023_0, {"VPUX"}}, + {"Dropout", V_2023_0, {"NPU"}}, {"Elu", V_2020_4, {"CPU", "GPU"}}, - {"Elu", V_2023_0, {"VPUX"}}, + {"Elu", V_2023_0, {"NPU"}}, // {"Einsum", V_2023_0, {"CPU", "GPU"}}, {"Equal", V_2020_4, {"CPU", "GPU"}}, - {"Equal", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Equal", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"Erf", V_2020_4, {"CPU", "GPU"}}, - {"Erf", V_2023_0, {"VPUX"}}, + {"Erf", V_2023_0, {"NPU"}}, {"Exp", V_2020_4, {"CPU", "GPU"}}, - {"Exp", V_2023_0, {"VPUX"}}, + {"Exp", V_2023_0, {"NPU"}}, {"Expand", V_2022_1, {"CPU", "GPU"}}, - {"Expand", V_2023_0, {"VPUX"}}, // Gets mapped to broadcast op and multiply op in the plugin. + {"Expand", V_2023_0, {"NPU"}}, // Gets mapped to broadcast op and multiply op in the plugin. {"EyeLike", V_2022_1, {"CPU"}}, - {"EyeLike", V_2023_0, {"VPUX"}}, // NoOP + {"EyeLike", V_2023_0, {"NPU"}}, // NoOP {"Flatten", V_2020_4, {"CPU", "GPU"}}, - {"Flatten", V_2023_0, {"VPUX"}}, + {"Flatten", V_2023_0, {"NPU"}}, {"Floor", V_2020_4, {"CPU", "GPU"}}, + {"Floor", V_2023_1, {"NPU"}}, {"Gather", V_2020_4, {"CPU", "GPU"}}, - {"Gather", V_2023_0, {"VPUX"}}, + {"Gather", V_2023_0, {"NPU"}}, {"GatherElements", V_2022_2, {"CPU", "GPU"}}, + {"GatherElements", V_2023_1, {"NPU"}}, {"GatherND", V_2021_4, {"CPU", "GPU"}}, + {"GatherND", V_2023_1, {"NPU"}}, {"Gemm", V_2020_4, {"CPU", "GPU"}}, - {"Gemm", V_2023_0, {"VPUX"}}, + {"Gemm", V_2023_0, {"NPU"}}, {"GlobalAveragePool", V_2020_4, {"CPU", "GPU"}}, - {"GlobalAveragePool", V_2023_0, {"VPUX"}}, + {"GlobalAveragePool", V_2023_0, {"NPU"}}, {"GlobalLpPool", V_2020_4, {"CPU", "GPU"}}, + {"GlobalLpPool", V_2023_1, {"NPU"}}, {"GlobalMaxPool", V_2022_1, {"CPU", "GPU"}}, + {"GlobalMaxPool", V_2023_1, {"NPU"}}, {"Greater", V_2020_4, {"CPU", "GPU"}}, - {"Greater", V_2023_0, {"VPUX"}}, + {"Greater", V_2023_0, {"NPU"}}, {"GreaterOrEqual", V_2022_1, {"CPU", "GPU"}}, - {"GreaterOrEqual", V_2023_0, {"VPUX"}}, + {"GreaterOrEqual", V_2023_0, {"NPU"}}, {"GridSample", V_2022_3, {"CPU"}}, {"GridSample", V_2023_0, {"GPU"}}, + {"GridSample", V_2023_1, {"NPU"}}, + {"HardMax", V_2023_1, {"CPU", "GPU", "NPU"}}, {"Identity", V_2020_4, {"CPU", "GPU"}}, - {"Identity", V_2023_0, {"VPUX"}}, // NoOP + {"Identity", V_2023_0, {"NPU"}}, // NoOP {"If", V_2022_3, {"CPU", "GPU"}}, + {"If", V_2023_1, {"NPU"}}, {"ImageScaler", V_2022_1, {"CPU", "GPU"}}, - {"ImageScaler", V_2023_0, {"VPUX"}}, + {"ImageScaler", V_2023_0, {"NPU"}}, {"InstanceNormalization", V_2020_4, {"CPU", "GPU"}}, - {"InstanceNormalization", V_2023_0, {"VPUX"}}, + {"InstanceNormalization", V_2023_0, {"NPU"}}, {"HardSigmoid", V_2020_4, {"CPU", "GPU"}}, + {"HardSigmoid", V_2023_1, {"NPU"}}, {"HardMax", V_2022_1, {"CPU", "GPU"}}, {"LeakyRelu", V_2020_4, {"CPU", "GPU"}}, - {"LeakyRelu", V_2023_0, {"VPUX"}}, + {"LeakyRelu", V_2023_0, {"NPU"}}, {"Less", V_2020_4, {"CPU", "GPU"}}, - {"Less", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Less", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"LessOrEqual", V_2022_1, {"CPU", "GPU"}}, - {"LessOrEqual", V_2023_0, {"VPUX"}}, + {"LessOrEqual", V_2023_0, {"NPU"}}, {"Log", V_2020_4, {"CPU", "GPU"}}, - {"Log", V_2023_0, {"VPUX"}}, + {"Log", V_2023_0, {"NPU"}}, {"LogSoftMax", V_2022_1, {"CPU", "GPU"}}, {"Loop", V_2021_4, {"CPU", "GPU"}}, + {"LpNormalization", V_2023_1, {"CPU", "GPU", "NPU"}}, + {"LpPool", V_2023_1, {"CPU", "GPU", "NPU"}}, {"LRN", V_2020_4, {"CPU", "GPU"}}, - {"LRN", V_2023_0, {"VPUX"}}, + {"LRN", V_2023_0, {"NPU"}}, {"LSTM", V_2020_4, {"CPU", "GPU"}}, + {"LSTM", V_2023_1, {"NPU"}}, {"MatMul", V_2020_4, {"CPU", "GPU"}}, - {"MatMul", V_2023_0, {"VPUX"}}, + {"MatMul", V_2023_0, {"NPU"}}, {"MatMulInteger", V_2022_1, {"CPU"}}, + {"MatMulInteger", V_2023_1, {"NPU"}}, {"Max", V_2020_4, {"CPU", "GPU"}}, - {"Max", V_2023_0, {"VPUX"}}, + {"Max", V_2023_0, {"NPU"}}, {"MaxPool", V_2020_4, {"CPU", "GPU"}}, - {"MaxPool", V_2023_0, {"VPUX"}}, + {"MaxPool", V_2023_0, {"NPU"}}, {"Mean", V_2020_4, {"CPU", "GPU"}}, - {"Mean", V_2023_0, {"VPUX"}}, + {"Mean", V_2023_0, {"NPU"}}, {"MeanVarianceNormalization", V_2022_1, {"CPU", "GPU"}}, + {"MeanVarianceNormalization", V_2023_1, {"NPU"}}, {"Min", V_2020_4, {"CPU", "GPU"}}, - {"Min", V_2023_0, {"VPUX"}}, + {"Min", V_2023_0, {"NPU"}}, {"Mod", V_2022_1, {"CPU", "GPU"}}, {"Mul", V_2020_4, {"CPU", "GPU"}}, - {"Mul", V_2023_0, {"VPUX"}}, + {"Mul", V_2023_0, {"NPU"}}, {"Neg", V_2020_4, {"CPU", "GPU"}}, - {"Neg", V_2023_0, {"VPUX"}}, + {"Neg", V_2023_0, {"NPU"}}, {"NonMaxSuppression", V_2021_1, {"CPU", "GPU"}}, + {"NonMaxSuppression", V_2023_1, {"NPU"}}, {"NonZero", V_2021_1, {"CPU"}}, {"NonZero", V_2023_0, {"GPU"}}, {"Not", V_2021_1, {"CPU", "GPU"}}, {"Not", V_2020_4, {"CPU", "GPU"}}, + {"Not", V_2023_1, {"NPU"}}, {"OneHot", V_2020_4, {"CPU", "GPU"}}, + {"OneHot", V_2023_1, {"NPU"}}, {"Or", V_2022_1, {"CPU", "GPU"}}, + {"Or", V_2023_1, {"NPU"}}, {"Pad", V_2020_4, {"CPU", "GPU"}}, - {"Pad", V_2023_0, {"VPUX"}}, + {"Pad", V_2023_0, {"NPU"}}, {"Pow", V_2020_4, {"CPU", "GPU"}}, - {"Pow", V_2023_0, {"VPUX"}}, + {"Pow", V_2023_0, {"NPU"}}, {"PRelu", V_2020_4, {"CPU", "GPU"}}, - {"PRelu", V_2023_0, {"VPUX"}}, + {"PRelu", V_2023_0, {"NPU"}}, {"QLinearMatMul", V_2022_3, {"CPU"}}, + // {"QLinearMatMul", V_2023_1, {"NPU"}}, {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}}, - {"QuantizeLinear", V_2023_0, {"VPUX"}}, + {"QuantizeLinear", V_2023_0, {"NPU"}}, + {"RNN", V_2023_1, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_1, {"NPU"}}, {"RandomNormal", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormal", V_2023_1, {"NPU"}}, {"Range", V_2022_1, {"CPU", "GPU"}}, - {"Range", V_2023_0, {"VPUX"}}, + {"Range", V_2023_0, {"NPU"}}, {"Reciprocal", V_2020_4, {"CPU", "GPU"}}, - {"Reciprocal", V_2023_0, {"VPUX"}}, + {"Reciprocal", V_2023_0, {"NPU"}}, {"ReduceL1", V_2022_1, {"CPU", "GPU"}}, + {"ReduceL1", V_2023_1, {"NPU"}}, {"ReduceL2", V_2022_1, {"CPU", "GPU"}}, + {"ReduceL2", V_2023_1, {"NPU"}}, {"ReduceLogSum", V_2020_4, {"CPU"}}, {"ReduceLogSum", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSum", V_2023_1, {"NPU"}}, {"ReduceLogSumExp", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSumExp", V_2023_1, {"NPU"}}, {"ReduceMax", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMax", V_2023_1, {"NPU"}}, {"ReduceMean", V_2020_4, {"CPU", "GPU"}}, - {"ReduceMean", V_2023_0, {"VPUX"}}, + {"ReduceMean", V_2023_0, {"NPU"}}, {"ReduceMin", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMin", V_2023_1, {"NPU"}}, {"ReduceProd", V_2020_4, {"CPU"}}, {"ReduceProd", V_2022_1, {"GPU"}}, + {"ReduceProd", V_2023_1, {"NPU"}}, {"ReduceSum", V_2020_4, {"CPU", "GPU"}}, + // {"ReduceSum", V_2023_1, {"NPU"}}, {"ReduceSumSquare", V_2020_4, {"CPU"}}, {"ReduceSumSquare", V_2022_1, {"CPU", "GPU"}}, + {"ReduceSumSquare", V_2023_1, {"NPU"}}, {"Relu", V_2020_4, {"CPU", "GPU"}}, - {"Relu", V_2023_0, {"VPUX"}}, + {"Relu", V_2023_0, {"NPU"}}, {"Resize", V_2020_4, {"CPU"}}, {"Resize", V_2022_1, {"GPU"}}, + {"Resize", V_2023_1, {"NPU"}}, {"Reshape", V_2020_4, {"CPU", "GPU"}}, - {"Reshape", V_2023_0, {"VPUX"}}, + {"Reshape", V_2023_0, {"NPU"}}, {"ReverseSequence", V_2022_1, {"CPU", "GPU"}}, {"RoiAlign", V_2021_1, {"CPU", "GPU"}}, + {"RoiAlign", V_2023_1, {"NPU"}}, {"Round", V_2021_4, {"CPU", "GPU"}}, + {"Round", V_2023_1, {"NPU"}}, {"Scatter", V_2022_1, {"CPU", "GPU"}}, + {"Scatter", V_2023_1, {"NPU"}}, {"ScatterElements", V_2022_1, {"CPU", "GPU"}}, + {"ScatterElements", V_2023_1, {"NPU"}}, {"ScatterND", V_2022_1, {"CPU", "GPU"}}, + {"ScatterND", V_2023_1, {"NPU"}}, {"Selu", V_2020_4, {"CPU", "GPU"}}, + {"Selu", V_2023_1, {"NPU"}}, {"Shape", V_2020_4, {"CPU", "GPU"}}, - {"Shape", V_2023_0, {"VPUX"}}, + {"Shape", V_2023_0, {"NPU"}}, {"Shrink", V_2022_1, {"CPU", "GPU"}}, - {"Shrink", V_2023_0, {"VPUX"}}, + {"Shrink", V_2023_0, {"NPU"}}, {"Sigmoid", V_2020_4, {"CPU", "GPU"}}, - {"Sigmoid", V_2023_0, {"VPUX"}}, + {"Sigmoid", V_2023_0, {"NPU"}}, {"Sign", V_2020_4, {"CPU"}}, {"Sign", V_2022_1, {"GPU"}}, - {"Sign", V_2023_0, {"VPUX"}}, + {"Sign", V_2023_0, {"NPU"}}, {"Sin", V_2022_1, {"CPU", "GPU"}}, - {"Sin", V_2023_0, {"VPUX"}}, + {"Sin", V_2023_0, {"NPU"}}, {"Sinh", V_2020_4, {"CPU"}}, + {"Sinh", V_2023_1, {"NPU"}}, {"Size", V_2022_1, {"CPU", "GPU"}}, + {"Size", V_2023_1, {"NPU"}}, {"Slice", V_2020_4, {"CPU", "GPU"}}, - {"Slice", V_2023_0, {"VPUX"}}, + {"Slice", V_2023_0, {"NPU"}}, {"Softmax", V_2020_4, {"CPU", "GPU"}}, - {"Softmax", V_2023_0, {"VPUX"}}, + {"Softmax", V_2023_0, {"NPU"}}, {"Softplus", V_2022_1, {"CPU", "GPU"}}, - {"Softplus", V_2023_0, {"VPUX"}}, + {"Softplus", V_2023_0, {"NPU"}}, {"Softsign", V_2022_1, {"CPU", "GPU"}}, {"SpaceToDepth", V_2020_4, {"CPU", "GPU"}}, - {"SpaceToDepth", V_2023_0, {"VPUX"}}, + {"SpaceToDepth", V_2023_0, {"NPU"}}, {"Split", V_2020_4, {"CPU", "GPU"}}, - {"Split", V_2023_0, {"VPUX"}}, + {"Split", V_2023_0, {"NPU"}}, {"Sqrt", V_2020_4, {"CPU", "GPU"}}, - {"Sqrt", V_2023_0, {"VPUX"}}, + {"Sqrt", V_2023_0, {"NPU"}}, {"Squeeze", V_2020_4, {"CPU", "GPU"}}, - {"Squeeze", V_2023_0, {"VPUX"}}, + {"Squeeze", V_2023_0, {"NPU"}}, {"Softsign", V_2020_4, {"CPU"}}, {"Sub", V_2020_4, {"CPU", "GPU"}}, - {"Sub", V_2023_0, {"VPUX"}}, + {"Sub", V_2023_0, {"NPU"}}, {"Sum", V_2020_4, {"CPU", "GPU"}}, - {"Sum", V_2023_0, {"VPUX"}}, + {"Sum", V_2023_0, {"NPU"}}, {"Tan", V_2020_4, {"CPU", "GPU"}}, + {"Tan", V_2023_1, {"NPU"}}, {"Tanh", V_2020_4, {"CPU", "GPU"}}, - {"Tanh", V_2023_0, {"VPUX"}}, + {"Tanh", V_2023_0, {"NPU"}}, {"ThresholdedRelu", V_2022_1, {"CPU", "GPU"}}, - {"ThresholdedRelu", V_2023_0, {"VPUX"}}, + {"ThresholdedRelu", V_2023_0, {"NPU"}}, {"Tile", V_2021_3, {"CPU", "GPU"}}, - {"Tile", V_2023_0, {"VPUX"}}, + {"Tile", V_2023_0, {"NPU"}}, {"Transpose", V_2020_4, {"CPU", "GPU"}}, - {"Transpose", V_2023_0, {"VPUX"}}, + {"Transpose", V_2023_0, {"NPU"}}, {"Trilu", V_2023_0, {"CPU", "GPU"}}, + {"Trilu", V_2023_1, {"NPU"}}, {"TopK", V_2020_4, {"CPU", "GPU"}}, - {"TopK", V_2023_0, {"VPUX"}}, + {"TopK", V_2023_0, {"NPU"}}, + {"Upsample", V_2020_4, {"CPU", "GPU"}}, {"Unsqueeze", V_2020_4, {"CPU", "GPU"}}, - {"Unsqueeze", V_2023_0, {"VPUX"}}, - {"Upsample", V_2021_1, {"CPU"}}, - {"Upsample", V_2021_4, {"GPU"}}, - {"Upsample", V_2023_0, {"VPUX"}}, + {"Unsqueeze", V_2023_0, {"NPU"}}, {"Where", V_2022_1, {"CPU", "GPU"}}, - {"Where", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Where", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"Xor", V_2022_1, {"CPU", "GPU"}}, + {"Xor", V_2023_1, {"NPU"}}, }; void DataOps::populate_types_supported() { - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_initializer_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_initializer_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_initializer_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_initializer_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_vpu_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_npu_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_cpu_.insert(std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_cpu_.insert( + std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_gpu_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_gpu_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_gpu_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_gpu_.insert(std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_gpu_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_gpu_.insert( + std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); } void DataOps::populate_op_mode_supported() { @@ -349,10 +439,10 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Gather", V_2020_4, {"All"}}); - no_dimension_supported_.push_back({"Greater", V_2023_0, {"VPUX"}}); + no_dimension_supported_.push_back({"Greater", V_2023_0, {"NPU"}}); no_dimension_supported_.push_back({"Less", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); - no_dimension_supported_.push_back({"Max", V_2023_0, {"VPUX"}}); + no_dimension_supported_.push_back({"Max", V_2023_0, {"NPU"}}); no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"QuantizeLinear", V_2021_4, {"All"}}); @@ -382,11 +472,14 @@ void DataOps::populate_op_mode_supported() { { UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3}, [this](const Node* node, const InitializedTensorSet&) { - // Abs is not supproted with INT8 or INT32 as input data type on GPU - if (device_id_.find("GPU") != std::string::npos) { + // Abs is not supproted with INT8 or INT32 as input data type on GPU and NPU + if ((device_id_.find("GPU") != std::string::npos) || + (device_id_.find("NPU") != std::string::npos)) { for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || - node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || + node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) return true; } } @@ -399,11 +492,14 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { // tensor type does not support select last index auto& attributes = node->GetAttributes(); - auto last_index_arg = attributes.count("select_last_index") > 0 ? attributes.at("select_last_index").i() : 0; + auto last_index_arg = + attributes.count("select_last_index") > 0 ? attributes.at("select_last_index").i() + : 0; if (last_index_arg != 0) return true; // tensor type supports float as input for argmax and argmin - if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) + if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) return true; return false; }}; @@ -415,7 +511,8 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { if (device_id_.find("GPU") != std::string::npos) { // int64 data type is not supported on GPU - const bool data_is_int64 = node->InputDefs()[0]->Type()->find("int64") != std::string::npos; + const bool data_is_int64 = + node->InputDefs()[0]->Type()->find("int64") != std::string::npos; return data_is_int64; } return false; @@ -506,9 +603,12 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { auto x_data_type = node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); auto y_data_type = node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); - // currently both inputs with int32 are not supported and also both input datatypes should be same - const bool A_is_int32 = node->InputDefs()[0]->Type()->find("int32") != std::string::npos; - const bool B_is_int32 = node->InputDefs()[1]->Type()->find("int32") != std::string::npos; + // currently both inputs with int32 are not supported + // and also both input datatypes should be same + const bool A_is_int32 = + node->InputDefs()[0]->Type()->find("int32") != std::string::npos; + const bool B_is_int32 = + node->InputDefs()[1]->Type()->find("int32") != std::string::npos; if ((A_is_int32 && B_is_int32) || (x_data_type != y_data_type)) return true; } @@ -589,11 +689,13 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { auto slope = node->InputDefs()[1]; // PRelu slope has to be an initializer or needs to come from a constant node - if (initializers.count(slope->Name())) + if (initializers.count(slope->Name())) { return false; - else { - for (auto input_node = node->InputNodesBegin(); input_node != node->InputNodesEnd(); ++input_node) { - if (GetInputCount(this->graph_viewer_.GetNode((*input_node).Index()), initializers) == 0) + } else { + for (auto input_node = node->InputNodesBegin(); + input_node != node->InputNodesEnd(); ++input_node) { + if (GetInputCount( + this->graph_viewer_.GetNode((*input_node).Index()), initializers) == 0) return false; } } @@ -603,12 +705,12 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"PRelu", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_arg = node->InputDefs()[1]; auto shape = input_arg->Shape(); // Reshape op with empty dim is Rejected for Myriad - //[TODO] Is this condition required anymore with Myriad removed? + // [TODO] Is this condition required anymore with Myriad removed? if (shape != nullptr) { for (const auto& dim : input_arg->Shape()->dim()) { if (utils::HasDimValue(dim) && dim.dim_value() == 0) @@ -638,7 +740,8 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { // INT32 dataype is not supported as input for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) return true; } } @@ -650,9 +753,11 @@ void DataOps::populate_op_mode_supported() { UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3}, [this](const Node* node, const InitializedTensorSet&) { if (device_id_.find("GPU") != std::string::npos) { - auto output_data_type = node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + auto output_data_type = + node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); // If the output of ScatterND op is BOOL, it is rejected for GPU. - if (output_data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) + if (output_data_type == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) return true; } return false; @@ -666,7 +771,8 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { // If the Input of Shrink op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) return true; } return false; @@ -714,10 +820,11 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Squeeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze - // If axes is an input, then we cannot produce a static graph. Conversion fails in convert_function_to_cnn_network. + // If axes is an input, then we cannot produce a static graph. + // Conversion fails in convert_function_to_cnn_network. for (size_t i = 0; i < node->InputDefs().size(); i++) { if (node->InputDefs()[i]->Name() == "axes") { return true; @@ -728,14 +835,15 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); if (upsample_attr.count("scales") > 0) { auto& upsample_arg = upsample_attr.at("scales"); auto float_size = upsample_arg.floats_size(); - if (float_size > 2 && (upsample_arg.floats(0) != 1.f || upsample_arg.floats(1) != 1.f)) { + if (float_size > 2 && + (upsample_arg.floats(0) != 1.f || upsample_arg.floats(1) != 1.f)) { return true; } } @@ -750,9 +858,12 @@ void DataOps::populate_op_mode_supported() { } } // x_arg supports only float, int8 and float16 type - if ((x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || - (x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || - (x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)) { + if ((x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || + (x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || + (x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)) { return false; } else { return true; @@ -849,9 +960,9 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { } else { auto dtype = type_proto->tensor_type().elem_type(); - if (device_id_.find("VPUX") != std::string::npos || device_id_.find("HETERO") != std::string::npos || + if (device_id_.find("NPU") != std::string::npos || device_id_.find("HETERO") != std::string::npos || device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) { - for (auto const& var : supported_types_vpu_) { + for (auto const& var : supported_types_npu_) { if ((var.first <= version_id_) && (var.second == dtype)) { return true; @@ -1079,7 +1190,9 @@ bool DataOps::node_is_supported(const std::mapsecond.find(optype) == opset->second.end() && op_fun == ops_supported_as_function.end()) { #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { - std::cout << "The operator is not available in OpenVINO ngraph operators list nor the operator is a special ONNX function" << std::endl; + std::cout << "The operator is not available in OpenVINO ngraph operators list" + << "nor the operator is a special ONNX function" + << std::endl; } #endif return false; @@ -1095,10 +1208,12 @@ std::vector DataOps::GetUnsupportedNodeIndices(std::unordered_setForEachDef([&ng_required_initializers, this](const NodeArg& node_arg, bool is_input) { - if(is_input && this->graph_viewer_.GetAllInitializedTensors().count(node_arg.Name())) { + graph_viewer_.GetNode(node_idx)->ForEachDef([&ng_required_initializers, this](const NodeArg& node_arg, + bool is_input) { + if (is_input && this->graph_viewer_.GetAllInitializedTensors().count(node_arg.Name())) { ng_required_initializers.insert(node_arg.Name()); - } }, true); + } }, + true); } else { unsupported_nodes_idx.push_back(node_idx); } @@ -1110,7 +1225,8 @@ bool DataOps::IsOpSupportedOnlyInModel(std::string name) { return ops_supported_only_in_model.find(name) != ops_supported_only_in_model.end(); } -bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, const Node* node) { +bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, + const Node* node) { if (node->OpType() == "Reshape") { const auto& shape_arg = node->InputDefs()[1]; if (ng_required_initializers.find(shape_arg->Name()) == ng_required_initializers.end()) { @@ -1119,15 +1235,20 @@ bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& } else if (node->OpType() == "Expand") { // nGraph only supports constant shape input values const auto& output = node->OutputDefs()[0]; - if (output->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) + if (output->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) return true; } else if (node->OpType() == "RoiAlign") { using onnx_dtype = ONNX_NAMESPACE::TensorProto_DataType; - onnx_dtype input_0_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype input_1_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype input_2_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[2]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype output_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_0_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_1_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_2_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[2]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype output_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); if ((input_0_data_type != onnx_dtype::TensorProto_DataType_FLOAT16) || (input_1_data_type != onnx_dtype::TensorProto_DataType_FLOAT16) || diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index cc968d02ea644..a5aa3f825602c 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -3,6 +3,11 @@ #pragma once #include +#include +#include +#include +#include +#include namespace onnxruntime { namespace openvino_ep { @@ -47,7 +52,7 @@ class DataOps { std::multimap op_list_; std::vector subgraph_supported_; std::vector no_dimension_supported_; - std::set supported_types_vpu_; + std::set supported_types_npu_; std::set supported_types_cpu_; std::set supported_types_gpu_; std::set supported_types_initializer_; @@ -64,14 +69,16 @@ class DataOps { const NodeIndex node_idx); public: - DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, std::string dev_id) : graph_viewer_(graph_viewer_param), version_id_(ver), device_id_(dev_id) { + DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, std::string dev_id) + : graph_viewer_(graph_viewer_param), version_id_(ver), device_id_(dev_id) { populate_op_mode_supported(); populate_types_supported(); } virtual std::vector GetUnsupportedNodeIndices(std::unordered_set& ng_required_initializers); virtual bool IsOpSupportedOnlyInModel(std::string name); - virtual bool SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, const Node* node); + virtual bool SpecialConditionForClusterSizeOne( + std::unordered_set& ng_required_initializers, const Node* node); virtual bool DoNotOmitSubGraph(const std::string& name); virtual bool InsertNode(const std::string& name); VersionNum GetVersion() const { return version_id_; } diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index be509b6743621..74369d39b9a24 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License #include "core/providers/shared_library/provider_api.h" +#include "utils.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245 5208) @@ -113,7 +114,8 @@ std::map> GetNgSupportedOps(const int onnx_op * supported_cluster + (UNsupported_node + rest_of_the_graph). This functions returns vector of all supported_clusters by nGraph */ std::vector> -GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes) { +GetPartitionedClusters(const std::vector& topological_order, + const std::vector& unsupported_nodes) { std::vector> ng_clusters; auto prev = topological_order.begin(); @@ -140,7 +142,10 @@ GetPartitionedClusters(const std::vector& topological_order, const st return ng_clusters; } -void IdentifyConnectedNodes(const GraphViewer& graph_viewer, NodeIndex curr_node_index, std::vector& cluster, std::vector& sub_cluster) { +void IdentifyConnectedNodes(const GraphViewer& graph_viewer, + NodeIndex curr_node_index, + std::vector& cluster, + std::vector& sub_cluster) { if (std::find(cluster.begin(), cluster.end(), curr_node_index) == cluster.end()) return; @@ -205,7 +210,8 @@ void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const auto& ext_node = graph_viewer.GetNode((*it).Index()); if (std::find(cluster.begin(), cluster.end(), ext_node->Index()) == cluster.end()) { - // Node is external to this_cluster. Search through its inputs to find the output that is generated by this_cluster. + // Node is external to this_cluster. Search through its inputs to + // find the output that is generated by this_cluster. std::set ext_node_inputs; ext_node->ForEachDef( [&ext_node_inputs](const NodeArg& arg, bool is_input) { diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h index 70f6954ea991c..c256cde97956e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.h +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h @@ -1,5 +1,15 @@ // Copyright (C) 2019-2022 Intel Corporation // Licensed under the MIT License +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include namespace onnxruntime { namespace openvino_ep { @@ -18,9 +28,14 @@ int GetOnnxOpSet(const GraphViewer& graph_viewer); std::map> GetNgSupportedOps(const int onnx_opset); std::vector> -GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes); - -void IdentifyConnectedNodes(const GraphViewer& graph_viewer, NodeIndex curr_node_index, std::vector& cluster, std::vector& sub_cluster); +GetPartitionedClusters( + const std::vector& topological_order, const std::vector& unsupported_nodes); + +void IdentifyConnectedNodes( + const GraphViewer& graph_viewer, + NodeIndex curr_node_index, + std::vector& cluster, + std::vector& sub_cluster); std::vector> GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector>& clusters); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index ccbc1acaa2f9e..3e17fb157b160 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -1,16 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include +#include + #include "core/providers/common.h" +#include "core/util/qmath.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "base_op_builder.h" -#include - namespace onnxruntime { namespace qnn { class BatchNormOpBuilder : public BaseOpBuilder { @@ -18,9 +22,446 @@ class BatchNormOpBuilder : public BaseOpBuilder { BatchNormOpBuilder() : BaseOpBuilder("BatchNormOpBuilder") {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(BatchNormOpBuilder); + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; + + std::pair CheckMinMax(float rmin, float rmax) const { + // Ensure a minimum range of 0.0001 (required by QNN) + rmax = std::max(rmax, rmin + 0.0001f); + + // Both QNN and ORT require the range to include 0.0f + rmin = std::min(rmin, 0.0f); + rmax = std::max(rmax, 0.0f); + + return std::make_pair(rmin, rmax); + } + + template + Status GetQminQmax(const Qnn_DataType_t qnn_data_type, + T& qmin, + T& qmax) const { + if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + Status GetQuantParams(float rmin, + float rmax, + const Qnn_DataType_t qnn_data_type, + float& scale, + int& zero_point) const { + std::tie(rmin, rmax) = CheckMinMax(rmin, rmax); + float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + + scale = (rmax - rmin) / (qmax - qmin); + const float initial_zero_point = qmin - (rmin / scale); + zero_point = static_cast(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point))); + // To match QNN quantization definition + zero_point = 0 - zero_point; + return Status::OK(); + } + + inline Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const uint8_t* raw_ptr, + double& value, + int& offset) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int8_t); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int16_t); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int32_t); + break; + } + case QNN_DATATYPE_INT_64: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int64_t); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint8_t); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint16_t); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint32_t); + break; + } + case QNN_DATATYPE_UINT_64: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint64_t); + break; + } + case QNN_DATATYPE_FLOAT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(float); + break; + } + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline Status AssertUnpackedTensorSize(const Qnn_DataType_t qnn_data_type, + const uint32_t channel, + const size_t raw_ptr_length) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int8_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int16_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int32_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_64: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int64_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint8_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint16_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint32_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_64: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint64_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_FLOAT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(float)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline Status ConvertToRawOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const std::vector& double_tensor, + std::vector& raw_tensor) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: { + raw_tensor.resize(double_tensor.size() * sizeof(int8_t)); + int8_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_16: { + raw_tensor.resize(double_tensor.size() * sizeof(int16_t)); + int16_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(int32_t)); + int32_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_64: { + raw_tensor.resize(double_tensor.size() * sizeof(int64_t)); + int64_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_8: { + raw_tensor.resize(double_tensor.size() * sizeof(uint8_t)); + uint8_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_16: { + raw_tensor.resize(double_tensor.size() * sizeof(uint16_t)); + uint16_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(uint32_t)); + uint32_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_64: { + raw_tensor.resize(double_tensor.size() * sizeof(uint64_t)); + uint64_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_FLOAT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(float)); + float* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UFIXED_POINT_32: + case QNN_DATATYPE_UFIXED_POINT_16: + case QNN_DATATYPE_UFIXED_POINT_8: + case QNN_DATATYPE_SFIXED_POINT_32: + case QNN_DATATYPE_SFIXED_POINT_16: + case QNN_DATATYPE_SFIXED_POINT_8: + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline double Dequantize(const OnnxInputInfo& info, + const double quant_value) const { + auto offset = static_cast(info.quant_param.scaleOffsetEncoding.offset); + auto scale = static_cast(info.quant_param.scaleOffsetEncoding.scale); + return (quant_value + offset) * scale; + } + + template + inline T Saturate(const T qmax, + const T qmin, + const T quant_value) const { + if (quant_value > qmax) { + return qmax; + } else if (quant_value < qmin) { + return qmin; + } else { + return quant_value; + } + } + + inline Status Quantize(const double double_value, + const float scale, + const int zero_point, + const Qnn_DataType_t qnn_data_type, + int& quant_value) const { + int qmin = 0; + int qmax = 255; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + quant_value = Saturate(qmax, qmin, static_cast(std::round((double_value / scale) - zero_point))); + return Status::OK(); + } + + Status PreprocessMean(const OnnxInputInfo& mean_info, + const bool is_npu_backend, + const uint8_t* mean_raw_ptr, + const size_t mean_raw_ptr_length, + std::vector& mean_out) const { + // tensor length (channel) + uint32_t channel = mean_info.shape[0]; + mean_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(mean_info.qnn_data_type, channel, mean_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double mean_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset)); + mean_out[i] = (is_npu_backend) ? Dequantize(mean_info, mean_value) : mean_value; + } + return Status::OK(); + } + + Status PreprocessStd(const OnnxInputInfo& var_info, + const bool is_npu_backend, + const uint8_t* var_raw_ptr, + const size_t var_raw_ptr_length, + const float epsilon, + std::vector& std_out) const { + // tensor length (channel) + uint32_t channel = var_info.shape[0]; + std_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(var_info.qnn_data_type, channel, var_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double var_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset)); + std_out[i] = (is_npu_backend) ? Dequantize(var_info, var_value) : var_value; + std_out[i] = std::sqrt(std_out[i] + static_cast(epsilon)); + } + return Status::OK(); + } + + Status PreprocessScale(const OnnxInputInfo& scale_info, + const bool is_npu_backend, + const uint8_t* scale_raw_ptr, + const size_t scale_raw_ptr_length, + const std::vector& std_double_tensor, + double& rmax, + double& rmin, + std::vector& scale_out) const { + // tensor length (channel) + uint32_t channel = scale_info.shape[0]; + scale_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(scale_info.qnn_data_type, channel, scale_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double scale_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset)); + scale_out[i] = (is_npu_backend) ? Dequantize(scale_info, scale_value) : scale_value; + scale_out[i] = scale_out[i] / std_double_tensor[i]; + rmax = std::max(rmax, scale_out[i]); + rmin = std::min(rmin, scale_out[i]); + } + return Status::OK(); + } + + Status PreprocessBias(const OnnxInputInfo& bias_info, + const bool is_npu_backend, + const uint8_t* bias_raw_ptr, + const size_t bias_raw_ptr_length, + const std::vector& scale_double_tensor, + const std::vector& mean_double_tensor, + double& rmax, + double& rmin, + std::vector& bias_out) const { + // tensor length (channel) + uint32_t channel = bias_info.shape[0]; + bias_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(bias_info.qnn_data_type, channel, bias_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double bias_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset)); + bias_out[i] = (is_npu_backend) ? Dequantize(bias_info, bias_value) : bias_value; + bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]); + rmax = std::max(rmax, bias_out[i]); + rmin = std::min(rmin, bias_out[i]); + } + return Status::OK(); + } + + Status Postprocess(const OnnxInputInfo& info, + const bool is_npu_backend, + const std::vector& double_tensor, + const double rmax, + const double rmin, + Qnn_QuantizeParams_t& quant_param, + std::vector& raw_tensor) const { + if (is_npu_backend) { + raw_tensor.resize(double_tensor.size()); + float scale = 0.0f; + int zero_point = 0; + ORT_RETURN_IF_ERROR(GetQuantParams(static_cast(rmin), + static_cast(rmax), + info.qnn_data_type, + scale, + zero_point)); + quant_param = QNN_QUANTIZE_PARAMS_INIT; + utils::InitializeQuantizeParam(quant_param, true, scale, zero_point); + for (size_t i = 0; i < double_tensor.size(); ++i) { + // onnx only supports 8 bits quantization + int quant_value_int = 0; + ORT_RETURN_IF_ERROR(Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int)); + if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + raw_tensor[i] = static_cast(quant_value_int); + } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + int8_t quant_value = static_cast(quant_value_int); + raw_tensor[i] = *reinterpret_cast(&quant_value); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type); + } + } + } else { + ORT_RETURN_IF_ERROR(ConvertToRawOnQnnDataType(info.qnn_data_type, double_tensor, raw_tensor)); + } + return Status::OK(); + } }; // BatchNorm is sensitive with data layout, no special validation so far @@ -34,11 +475,6 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, // Still do it here so hopefully QNN Op validation API can tell us some details why it's not supported return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } else { - NodeAttrHelper node_helper(node_unit); - const float default_epsilon = 1e-05f; - const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec. - ORT_RETURN_IF(abs(epsilon - default_epsilon) > default_epsilon, "QNN BatchNorm doesn't support epsilon."); - const auto& inputs = node_unit.Inputs(); ORT_ENFORCE(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec."); @@ -56,11 +492,16 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, std::vector scale_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale)."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[1].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic scale."); ORT_RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels, "QNN BatchNorm input 1 (scale) must have 1D shape [channel]."); std::vector bias_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias)."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[2].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic bias."); + ORT_RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels, "QNN BatchNorm input 2 (bias) must have 1D shape [channel]."); @@ -68,13 +509,15 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[3].node_arg, mean_shape), "Cannot get shape of input 3 (mean)."); ORT_RETURN_IF(mean_shape.size() != 1 || mean_shape[0] != num_channels, "QNN BatchNorm input 3 (mean) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), "QNN BatchNorm doesn't support dynamic mean."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic mean."); std::vector var_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[4].node_arg, var_shape), "Cannot get shape of input 4 (var)."); ORT_RETURN_IF(var_shape.size() != 1 || var_shape[0] != num_channels, "QNN BatchNorm input 4 (var) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), "QNN BatchNorm doesn't support dynamic var."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic var."); ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN BatchNorm only support 1 output."); } @@ -82,6 +525,134 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + ORT_UNUSED_PARAMETER(logger); + + const auto& inputs = node_unit.Inputs(); + bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + // + // Input 0 + // + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); + + // + // Input 1: scale + // Input 2: bias + // QNN only accept 3 input. We need to first combine mean and variance into scale and bias. + // + { + const std::string& scale_name = inputs[1].node_arg.Name(); + const std::string& bias_name = inputs[2].node_arg.Name(); + OnnxInputInfo var_info = {}; + OnnxInputInfo mean_info = {}; + OnnxInputInfo scale_info = {}; + OnnxInputInfo bias_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], scale_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], bias_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[3], mean_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[4], var_info)); + + // scale, bias, mean, and var must be initializers + ORT_RETURN_IF_NOT(scale_info.is_initializer, "scale must be initializers"); + ORT_RETURN_IF_NOT(bias_info.is_initializer, "bias must be initializers"); + ORT_RETURN_IF_NOT(mean_info.is_initializer, "mean must be initializers"); + ORT_RETURN_IF_NOT(var_info.is_initializer, "var must be initializers"); + + std::vector scale_unpacked_tensor; + std::vector bias_unpacked_tensor; + std::vector var_unpacked_tensor; + std::vector mean_unpacked_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*scale_info.initializer_tensor, scale_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*bias_info.initializer_tensor, bias_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*mean_info.initializer_tensor, mean_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*var_info.initializer_tensor, var_unpacked_tensor)); + + std::vector mean_double_tensor; + std::vector std_double_tensor; + std::vector scale_double_tensor; + std::vector bias_double_tensor; + + NodeAttrHelper node_helper(node_unit); + const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec. + + double scale_rmax = std::numeric_limits::min(); + double scale_rmin = std::numeric_limits::max(); + double bias_rmax = std::numeric_limits::min(); + double bias_rmin = std::numeric_limits::max(); + + // Calculate and convert new scale, new bias, mean and std to double array (may be dequantized) + ORT_RETURN_IF_ERROR(PreprocessMean(mean_info, + is_npu_backend, + mean_unpacked_tensor.data(), + mean_unpacked_tensor.size(), + mean_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessStd(var_info, + is_npu_backend, + var_unpacked_tensor.data(), + var_unpacked_tensor.size(), + epsilon, + std_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessScale(scale_info, + is_npu_backend, + scale_unpacked_tensor.data(), + scale_unpacked_tensor.size(), + std_double_tensor, + scale_rmax, + scale_rmin, + scale_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessBias(bias_info, + is_npu_backend, + bias_unpacked_tensor.data(), + bias_unpacked_tensor.size(), + scale_double_tensor, + mean_double_tensor, + bias_rmax, + bias_rmin, + bias_double_tensor)); + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(scale_name)) { + std::vector scale_raw_tensor; + Qnn_QuantizeParams_t scale_quant_param = scale_info.quant_param; + ORT_RETURN_IF_ERROR(Postprocess(scale_info, + is_npu_backend, + scale_double_tensor, + scale_rmax, + scale_rmin, + scale_quant_param, + scale_raw_tensor)); + Qnn_TensorType_t scale_tensor_type = GetInputTensorType(qnn_model_wrapper, scale_name); + QnnTensorWrapper input_tensorwrapper(scale_name, scale_tensor_type, scale_info.qnn_data_type, scale_quant_param, + std::move(scale_info.shape), std::move(scale_raw_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } + input_names.push_back(scale_name); + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(bias_name)) { + std::vector bias_raw_tensor; + Qnn_QuantizeParams_t bias_quant_param = bias_info.quant_param; + ORT_RETURN_IF_ERROR(Postprocess(bias_info, + is_npu_backend, + bias_double_tensor, + bias_rmax, + bias_rmin, + bias_quant_param, + bias_raw_tensor)); + Qnn_TensorType_t bias_tensor_type = GetInputTensorType(qnn_model_wrapper, bias_name); + QnnTensorWrapper input_tensorwrapper(bias_name, bias_tensor_type, bias_info.qnn_data_type, bias_quant_param, + std::move(bias_info.shape), std::move(bias_raw_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } + input_names.push_back(bias_name); + } + + return Status::OK(); +} + void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 943dc8128a133..d3aafcbecd322 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -404,10 +404,6 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer } } - if (num_of_partitions > 1) { - ORT_ENFORCE(!context_cache_enabled_, "Only support single partition for context cache feature."); - } - const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions, ", number of nodes in the graph: ", num_nodes_in_graph, ", number of nodes supported by QNN: ", num_of_supported_nodes); @@ -485,7 +481,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused bool is_ctx_file_exist = qnn_cache_model_handler_->GetIsContextCacheFileExists(); if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) { - ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature."); + ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); // Load and execute from cached context if exist ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->LoadQnnCtxFromOnnxModel(graph_viewer, @@ -509,7 +505,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger)); if (context_cache_enabled_ && !is_qnn_ctx_model) { - ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature."); + ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GenerateCtxCacheOnnxModel(context_buffer.get(), diff --git a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh index 5830c9dd0bf27..f87b436d04a17 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh +++ b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh @@ -58,7 +58,7 @@ auto GetCKSoftmaxTypeStringAndOps() { auto arg = impl->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, alpha, beta, params->input, params->output, nop, nop); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index 463c1cf0d2ea6..c0b7d4722d3e4 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -173,17 +173,17 @@ class RocmKernel : public OpKernel { return provider_->PerThreadDefaultMiopenHandle(); } + inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); + return gpu_data_transfer->CopyTensorAsync(src, dst, stream); + } + protected: template inline const T* GetConstOnes(size_t count, hipStream_t stream) const { return provider_->template GetConstOnes(count, stream); } - inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); - return gpu_data_transfer->CopyTensorAsync(src, dst, stream); - } - inline int GetDeviceId() const { return provider_->GetDeviceId(); } private: diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh index 86d023886cfaf..2518f45e0995e 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh @@ -61,7 +61,7 @@ auto GetCKGemmTypeStringAndOps() { params->lda, params->ldb, params->ldc, nop, nop, nop); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; @@ -164,7 +164,7 @@ auto GetCKStridedBatchedGemmTypeStringAndOps() { auto zero = ToHipType::FromFloat(0.0f); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); auto nop = Nop{}; auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, @@ -174,7 +174,7 @@ auto GetCKStridedBatchedGemmTypeStringAndOps() { params->batch, nop, nop, nop); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index d5f9de26ada22..b9c0cdcc1c341 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -221,7 +221,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != HIPBLAS_STATUS_SUCCESS, - "[hipBLASLt] Solution #", i, " failed: algo ", algo_index, " not supported (", params->Signature(), ")"); + "[hipBLASLt] Solution #", i, " failed: algo ", algo_index, " not supported"); IAllocatorUniquePtr workspace_buffer; if (workspace_size > 0) { diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h index 8e894e63c5de1..a391d1af8868c 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h @@ -168,8 +168,7 @@ auto GetRocBlasGemmTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status), - " (", params->Signature(), ")"); + "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); return Status::OK(); }; @@ -238,8 +237,7 @@ auto GetRocBlasBatchedGemmTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status), - " (", params->Signature(), ")"); + "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); return Status::OK(); }; @@ -308,8 +306,7 @@ auto GetRocBlasStridedBatchedGemmTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status), - " (", params->Signature(), ")"); + "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); return Status::OK(); }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index ef1f0bf9f8d0e..a1fc67ff60b6f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -824,6 +824,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { cache_path_ = info.engine_cache_path; } + // use a more global cache if given + if (timing_cache_enable_) { + if (!info.timing_cache_path.empty()) { + global_cache_path_ = info.timing_cache_path; + } else { + global_cache_path_ = cache_path_; + } + } engine_decryption_enable_ = info.engine_decryption_enable; if (engine_decryption_enable_) { engine_decryption_lib_path_ = info.engine_decryption_lib_path; @@ -928,6 +936,15 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_ENGINE_CACHE_PATH is deprecated! Please use ORT_TENSORRT_CACHE_PATH to specify engine cache path"; } } + if (timing_cache_enable_) { + std::string timing_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCachePath); + // use a more global cache if given + if (!timing_cache_path.empty()) { + global_cache_path_ = timing_cache_path; + } else { + global_cache_path_ = cache_path_; + } + } const std::string engine_decryption_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionEnable); if (!engine_decryption_enable_env.empty()) { @@ -1019,6 +1036,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv throw std::runtime_error("Failed to create directory " + cache_path_); } } + if (!global_cache_path_.empty() && !fs::is_directory(global_cache_path_)) { + if (!fs::create_directory(global_cache_path_)) { + throw std::runtime_error("Failed to create directory " + global_cache_path_); + } + } { auto lock = GetApiLock(); runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger())); @@ -1104,6 +1126,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_dump_subgraphs: " << dump_subgraphs_ << ", trt_engine_cache_enable: " << engine_cache_enable_ << ", trt_cache_path: " << cache_path_ + << ", trt_global_cache_path: " << global_cache_path_ << ", trt_engine_decryption_enable: " << engine_decryption_enable_ << ", trt_engine_decryption_lib_path: " << engine_decryption_lib_path_ << ", trt_force_sequential_engine_build: " << force_sequential_engine_build_ @@ -2199,7 +2222,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectornode_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, - force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, + global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics}; *state = p.release(); return 0; @@ -2460,7 +2483,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector runtime_ = nullptr; OrtMutex tensorrt_mu_; int device_id_; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index cb7a568d09130..3ead33f9131d9 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -25,7 +25,7 @@ constexpr const char* kDLAEnable = "trt_dla_enable"; constexpr const char* kDLACore = "trt_dla_core"; constexpr const char* kDumpSubgraphs = "trt_dump_subgraphs"; constexpr const char* kEngineCacheEnable = "trt_engine_cache_enable"; -constexpr const char* kCachePath = "trt_engine_cache_path"; +constexpr const char* kEngineCachePath = "trt_engine_cache_path"; constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable"; constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path"; constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build"; @@ -33,7 +33,8 @@ constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine constexpr const char* kContextMemorySharingEnable = "trt_context_memory_sharing_enable"; constexpr const char* kLayerNormFP32Fallback = "trt_layer_norm_fp32_fallback"; constexpr const char* kTimingCacheEnable = "trt_timing_cache_enable"; -constexpr const char* kForceTimingCacheMatch = "trt_force_timing_cache_match"; +constexpr const char* kTimingCachePath = "trt_timing_cache_path"; +constexpr const char* kForceTimingCacheMatch = "trt_force_timing_cache"; constexpr const char* kDetailedBuildLog = "trt_detailed_build_log"; constexpr const char* kBuildHeuristics = "trt_build_heuristics_enable"; constexpr const char* kSparsityEnable = "trt_sparsity_enable"; @@ -76,13 +77,14 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kDLACore, info.dla_core) .AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs) .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable) - .AddAssignmentToReference(tensorrt::provider_option_names::kCachePath, info.engine_cache_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path) .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path) .AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build) .AddAssignmentToReference(tensorrt::provider_option_names::kContextMemorySharingEnable, info.context_memory_sharing_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kLayerNormFP32Fallback, info.layer_norm_fp32_fallback) .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCacheEnable, info.timing_cache_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCachePath, info.timing_cache_path) .AddAssignmentToReference(tensorrt::provider_option_names::kForceTimingCacheMatch, info.force_timing_cache) .AddAssignmentToReference(tensorrt::provider_option_names::kDetailedBuildLog, info.detailed_build_log) .AddAssignmentToReference(tensorrt::provider_option_names::kBuildHeuristics, info.build_heuristics_enable) @@ -115,7 +117,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.dla_core)}, {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)}, - {tensorrt::provider_option_names::kCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, + {tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, @@ -123,6 +125,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.context_memory_sharing_enable)}, {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.layer_norm_fp32_fallback)}, {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.timing_cache_enable)}, + {tensorrt::provider_option_names::kTimingCachePath, MakeStringWithClassicLocale(info.timing_cache_path)}, {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.force_timing_cache)}, {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)}, {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.build_heuristics_enable)}, @@ -142,7 +145,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensorRTProviderOptionsV2& info) { auto empty_if_null = [](const char* s) { return s != nullptr ? std::string{s} : std::string{}; }; const std::string kInt8CalibTable_ = empty_if_null(info.trt_int8_calibration_table_name); - const std::string kCachePath_ = empty_if_null(info.trt_engine_cache_path); + const std::string kEngineCachePath_ = empty_if_null(info.trt_engine_cache_path); + const std::string kTimingCachePath_ = empty_if_null(info.trt_timing_cache_path); const std::string kTacticSources_ = empty_if_null(info.trt_tactic_sources); const std::string kDecryptionLibPath_ = empty_if_null(info.trt_engine_decryption_lib_path); const std::string kExtraPluginLibPaths_ = empty_if_null(info.trt_extra_plugin_lib_paths); @@ -164,13 +168,14 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.trt_dla_core)}, {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.trt_dump_subgraphs)}, {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)}, - {tensorrt::provider_option_names::kCachePath, kCachePath_}, + {tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_}, {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)}, {tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_}, {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)}, {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.trt_context_memory_sharing_enable)}, {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.trt_layer_norm_fp32_fallback)}, {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.trt_timing_cache_enable)}, + {tensorrt::provider_option_names::kTimingCachePath, kTimingCachePath_}, {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.trt_force_timing_cache)}, {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.trt_detailed_build_log)}, {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.trt_build_heuristics_enable)}, @@ -204,6 +209,27 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options if (provider_options == nullptr) { return; } + auto copy_string_if_needed = [&](std::string& s_in) { + if (string_copy) { + char* dest = nullptr; + auto str_size = s_in.size(); + if (str_size == 0) { + return (const char*)nullptr; + } else { + dest = new char[str_size + 1]; +#ifdef _MSC_VER + strncpy_s(dest, str_size + 1, s_in.c_str(), str_size); +#else + strncpy(dest, s_in.c_str(), str_size); +#endif + dest[str_size] = '\0'; + return (const char*)dest; + } + } else { + return s_in.c_str(); + } + }; + TensorrtExecutionProviderInfo internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options); auto& trt_provider_options_v2 = *reinterpret_cast(provider_options); trt_provider_options_v2.device_id = internal_options.device_id; @@ -220,24 +246,7 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_fp16_enable = internal_options.fp16_enable; trt_provider_options_v2.trt_int8_enable = internal_options.int8_enable; - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.int8_calibration_table_name.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_int8_calibration_table_name = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size); -#else - strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_int8_calibration_table_name = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_int8_calibration_table_name = internal_options.int8_calibration_table_name.c_str(); - } + trt_provider_options_v2.trt_int8_calibration_table_name = copy_string_if_needed(internal_options.int8_calibration_table_name); trt_provider_options_v2.trt_int8_use_native_calibration_table = internal_options.int8_use_native_calibration_table; trt_provider_options_v2.trt_dla_enable = internal_options.dla_enable; @@ -245,45 +254,12 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_dump_subgraphs = internal_options.dump_subgraphs; trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable; - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.engine_cache_path.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_engine_cache_path = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.engine_cache_path.c_str(), str_size); -#else - strncpy(dest, internal_options.engine_cache_path.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_engine_cache_path = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_engine_cache_path = internal_options.engine_cache_path.c_str(); - } + trt_provider_options_v2.trt_engine_cache_path = copy_string_if_needed(internal_options.engine_cache_path); + trt_provider_options_v2.trt_timing_cache_path = copy_string_if_needed(internal_options.timing_cache_path); trt_provider_options_v2.trt_engine_decryption_enable = internal_options.engine_decryption_enable; - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.engine_decryption_lib_path.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_engine_decryption_lib_path = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.engine_decryption_lib_path.c_str(), str_size); -#else - strncpy(dest, internal_options.engine_decryption_lib_path.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_engine_decryption_lib_path = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_engine_decryption_lib_path = internal_options.engine_decryption_lib_path.c_str(); - } + trt_provider_options_v2.trt_engine_decryption_lib_path = copy_string_if_needed(internal_options.engine_decryption_lib_path); trt_provider_options_v2.trt_force_sequential_engine_build = internal_options.force_sequential_engine_build; trt_provider_options_v2.trt_context_memory_sharing_enable = internal_options.context_memory_sharing_enable; @@ -296,100 +272,11 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_builder_optimization_level = internal_options.builder_optimization_level; trt_provider_options_v2.trt_auxiliary_streams = internal_options.auxiliary_streams; - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.tactic_sources.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_tactic_sources = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.tactic_sources.c_str(), str_size); -#else - strncpy(dest, internal_options.tactic_sources.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_tactic_sources = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_tactic_sources = internal_options.tactic_sources.c_str(); - } - - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.extra_plugin_lib_paths.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_extra_plugin_lib_paths = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.extra_plugin_lib_paths.c_str(), str_size); -#else - strncpy(dest, internal_options.extra_plugin_lib_paths.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_extra_plugin_lib_paths = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_extra_plugin_lib_paths = internal_options.extra_plugin_lib_paths.c_str(); - } - - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.profile_min_shapes.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_profile_min_shapes = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.profile_min_shapes.c_str(), str_size); -#else - strncpy(dest, internal_options.profile_min_shapes.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_profile_min_shapes = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_profile_min_shapes = internal_options.profile_min_shapes.c_str(); - } - - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.profile_max_shapes.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_profile_max_shapes = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.profile_max_shapes.c_str(), str_size); -#else - strncpy(dest, internal_options.profile_max_shapes.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_profile_max_shapes = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_profile_max_shapes = internal_options.profile_max_shapes.c_str(); - } - - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.profile_opt_shapes.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_profile_opt_shapes = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.profile_opt_shapes.c_str(), str_size); -#else - strncpy(dest, internal_options.profile_opt_shapes.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_profile_opt_shapes = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_profile_opt_shapes = internal_options.profile_opt_shapes.c_str(); - } + trt_provider_options_v2.trt_tactic_sources = copy_string_if_needed(internal_options.tactic_sources); + trt_provider_options_v2.trt_extra_plugin_lib_paths = copy_string_if_needed(internal_options.extra_plugin_lib_paths); + trt_provider_options_v2.trt_profile_min_shapes = copy_string_if_needed(internal_options.profile_min_shapes); + trt_provider_options_v2.trt_profile_max_shapes = copy_string_if_needed(internal_options.profile_max_shapes); + trt_provider_options_v2.trt_profile_opt_shapes = copy_string_if_needed(internal_options.profile_opt_shapes); trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 61a6bf08211be..b16543aa3d7dd 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -38,6 +38,7 @@ struct TensorrtExecutionProviderInfo { bool context_memory_sharing_enable{false}; bool layer_norm_fp32_fallback{false}; bool timing_cache_enable{false}; + std::string timing_cache_path{""}; bool force_timing_cache{false}; bool detailed_build_log{false}; bool build_heuristics_enable{false}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index d7e13df000272..426584553f349 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -103,6 +103,7 @@ struct Tensorrt_Provider : Provider { info.context_memory_sharing_enable = options.trt_context_memory_sharing_enable != 0; info.layer_norm_fp32_fallback = options.trt_layer_norm_fp32_fallback != 0; info.timing_cache_enable = options.trt_timing_cache_enable != 0; + info.timing_cache_path = options.trt_timing_cache_path == nullptr ? "" : options.trt_timing_cache_path; info.force_timing_cache = options.trt_force_timing_cache != 0; info.detailed_build_log = options.trt_detailed_build_log != 0; info.build_heuristics_enable = options.trt_build_heuristics_enable != 0; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 1e0af51567ca0..af3293dd3d92c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -49,9 +49,8 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, NodeAttrHelper helper(node); const auto group = helper.Get("group", static_cast(1)); const auto& input_defs = node.InputDefs(); - const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - const auto& weight_shape = weight_tensor.dims(); - + std::vector weight_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape"); options.set("strides", emscripten::val::array(strides)); options.set("dilations", emscripten::val::array(dilations)); options.set("groups", group); @@ -278,25 +277,28 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, + const WebnnDeviceType device_type, const logging::Logger& logger) const { const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); const auto& weight_name = input_defs[1]->Name(); - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); - if (tensor.dims().size() != 4) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size() - << " Only conv 2d is supported."; + // WebNN CPU backend (XNNPACK) requires the filter operand to be a constant. + // https://github.com/google/XNNPACK/blob/master/src/subgraph/convolution-2d.c#L739 + if (device_type == WebnnDeviceType::CPU) { + if (Contains(initializers, weight_name)) { + const auto& tensor = *initializers.at(weight_name); + if (tensor.dims().size() != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size() + << " Only conv 2d is supported."; + return false; + } + } else { + LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; return false; } - } else { - LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; - return false; } - return true; } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 26c739e9a1ce1..02a3d16b5b64f 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -26,11 +26,7 @@ WebNNExecutionProvider::WebNNExecutionProvider( ORT_THROW("Failed to get ml from navigator."); } emscripten::val context_options = emscripten::val::object(); - // Currently WebNN implementation in Chromium temporarily reuses the MLContextOptions - // defined in Model Loader API, which uses MLDevicePreference instead of MLDeviceType - // defined in WebNN. Because there's an ongoing spec discussion to simplify this API at - // https://github.com/webmachinelearning/webnn/issues/302. - context_options.set("devicePreference", emscripten::val(webnn_device_flags)); + context_options.set("deviceType", emscripten::val(webnn_device_flags)); // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { preferred_layout_ = DataLayout::NHWC; diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 5f1d5036e8310..b827c28f129b1 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -25,10 +25,11 @@ #if !defined(ORT_MINIMAL_BUILD) static constexpr uint32_t min_ort_version_with_optional_io_support = 8; static constexpr uint32_t min_ort_version_with_variadic_io_support = 14; +static constexpr uint32_t min_ort_version_with_custom_version = 17; #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -static constexpr uint32_t min_ort_version_with_compute_v2_support = 17; +static constexpr uint32_t min_ort_version_with_compute_v2_support = 16; static constexpr uint32_t min_ort_version_with_shape_inference = 17; #endif @@ -698,8 +699,19 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust KernelDefBuilder def_builder; def_builder.SetName(op->GetName(op)) - .SetDomain(domain) - .SinceVersion(1); + .SetDomain(domain); + + if (op->version >= min_ort_version_with_custom_version) { + if (op->GetStartVersion && op->GetEndVersion) { + def_builder.SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op)); + } else if (op->GetStartVersion) { + def_builder.SinceVersion(op->GetStartVersion(op)); + } else { + def_builder.SinceVersion(1); + } + } else { + def_builder.SinceVersion(1); + } // GetInputMemoryType was introduced in ver 13. This check allows custom ops compiled using older versions // to work with newer versions (> 12) of the ORT binary. @@ -820,7 +832,11 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustom schema.TypeConstraint(output_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types"); } schema.SetDomain(domain); - schema.SinceVersion(1); + if (op->version >= min_ort_version_with_custom_version && op->GetStartVersion) { + schema.SinceVersion(op->GetStartVersion(op)); + } else { + schema.SinceVersion(1); + } schema.AllowUncheckedAttributes(); if (op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn) { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d307f79c372ed..df4dd55417755 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1432,7 +1432,7 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O if (legacy_ov_options->device_type != nullptr) ov_options_converted_map["device_type"] = legacy_ov_options->device_type; - ov_options_converted_map["enable_vpu_fast_compile"] = legacy_ov_options->enable_vpu_fast_compile; + ov_options_converted_map["enable_npu_fast_compile"] = legacy_ov_options->enable_npu_fast_compile; if (legacy_ov_options->device_id != nullptr) ov_options_converted_map["device_id"] = legacy_ov_options->device_id; @@ -1931,6 +1931,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor if (ptr != nullptr) { delete[] ptr->trt_int8_calibration_table_name; delete[] ptr->trt_engine_cache_path; + delete[] ptr->trt_timing_cache_path; delete[] ptr->trt_engine_decryption_lib_path; delete[] ptr->trt_tactic_sources; delete[] ptr->trt_extra_plugin_lib_paths; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index a72f563601512..2027b592326df 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -479,7 +479,7 @@ std::unique_ptr CreateExecutionProviderInstance( // So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance. // (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance // and TRT EP instance, so it won't be released.) - std::string calibration_table, cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile; + std::string calibration_table, cache_path, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtTensorRTProviderOptionsV2 params; @@ -623,6 +623,13 @@ std::unique_ptr CreateExecutionProviderInstance( } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_enable' should be 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "trt_timing_cache_path") { + if (!option.second.empty()) { + timing_cache_path = option.second; + params.trt_timing_cache_path = timing_cache_path.c_str(); + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_path' should be a path string i.e. 'cache_folder/'.\n"); + } } else if (option.first == "trt_force_timing_cache") { if (option.second == "True" || option.second == "true") { params.trt_force_timing_cache = true; @@ -806,10 +813,10 @@ std::unique_ptr CreateExecutionProviderInstance( if (option.first == "device_type") { OV_provider_options_map[option.first] = option.second; continue; - } else if (option.first == "enable_vpu_fast_compile") { + } else if (option.first == "enable_npu_fast_compile") { if (!(option.second == "True" || option.second == "true" || option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_vpu_fast_compile: ", option.second); + ORT_THROW("Invalid value passed for enable_npu_fast_compile: ", option.second); } OV_provider_options_map[option.first] = option.second; } else if (option.first == "enable_opencl_throttling") { @@ -1214,14 +1221,14 @@ void addGlobalMethods(py::module& m) { #ifdef ENABLE_ATEN m.def("register_aten_op_executor", - [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void { - size_t is_tensor_argument_address_int, aten_op_executor_address_int; + [](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void { + size_t is_cpu_argument_address_int, aten_op_executor_address_int; ORT_THROW_IF_ERROR( - ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int)); + ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int)); ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int)); - void* p_is_tensor_argument = reinterpret_cast(is_tensor_argument_address_int); + void* p_is_cpu_argument = reinterpret_cast(is_cpu_argument_address_int); void* p_aten_op_executor = reinterpret_cast(aten_op_executor_address_int); - contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); + contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor); }); #endif } diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 5bb6bcc38b6fe..a5bcbce89bac6 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -60,11 +60,11 @@ struct OrtStatus { #elif OPENVINO_CONFIG_GPU_FP16 #define BACKEND_OPENVINO "-OPENVINO_GPU_FP16" -#elif OPENVINO_CONFIG_VPUX_FP16 -#define BACKEND_OPENVINO "-OPENVINO_VPUX_FP16" +#elif OPENVINO_CONFIG_NPU_FP16 +#define BACKEND_OPENVINO "-OPENVINO_NPU_FP16" -#elif OPENVINO_CONFIG_VPUX_U8 -#define BACKEND_OPENVINO "-OPENVINO_VPUX_U8" +#elif OPENVINO_CONFIG_NPU_U8 +#define BACKEND_OPENVINO "-OPENVINO_NPU_U8" #elif OPENVINO_CONFIG_MULTI #define BACKEND_OPENVINO "-OPENVINO_MULTI" diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 272727a9f5375..9b68aef57656e 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -198,7 +198,9 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, "Gelu": self._infer_Gelu, "GemmFastGelu": self._infer_GemmFastGelu, + "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, + "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, "MultiHeadAttention": self._infer_MultiHeadAttention, @@ -2317,6 +2319,9 @@ def _infer_QuickGelu(self, node): # noqa: N802 def _infer_GemmFastGelu(self, node): # noqa: N802 self._compute_matmul_shape(node) + def _infer_GemmFloat8(self, node): # noqa: N802 + self._compute_matmul_shape(node) + def _infer_LayerNormalization(self, node): # noqa: N802 self._propagate_shape_and_type(node) if len(node.output) > 1: @@ -2372,6 +2377,11 @@ def _infer_SkipLayerNormalization(self, node): # noqa: N802 def _infer_GroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_SkipGroupNorm(self, node): # noqa: N802 + self._propagate_shape_and_type(node, 0, 0) + if len(node.output) > 1: + self._propagate_shape_and_type(node, 0, 1) + def _infer_BiasSplitGelu(self, node): # noqa: N802 input_shape = self._get_shape(node, 0) bias_shape = self._get_shape(node, 1) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 4228c892d03ae..b32ae64c5b0c0 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1275,7 +1275,7 @@ def find_past_seq_len_usage(subg: GraphProto): def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0): past_seq_len = past_seq_len_input if past_seq_len not in model.get_graphs_input_names(): - # Replace model input for past sequence length + # Add model input for past sequence length new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1]) model.model.graph.input.append(new_input) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 3c5029ac5752f..44d15b619ec7a 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -427,6 +427,16 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [1, 2, 1, 0, 0, 0], ) + attn_mask_nodes_5 = self.model.match_parent_path( + add_qk, + ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 0, 2, 1, 0, 0, 0], + ) + attn_mask_nodes_6 = self.model.match_parent_path( + add_qk, + ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 2, 1, 0, 0, 0], + ) if attn_mask_nodes_1 is not None: _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1 attn_mask = slice_mask_1.output[0] @@ -439,6 +449,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): elif attn_mask_nodes_4 is not None: # Reshape from (B,1,S,T) to (B,N,S,T) add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0]) + elif attn_mask_nodes_5 is not None: + # The mask has already been reshaped to (B,N,S,T) + add_qk_str = attn_mask_nodes_5[0].output[0] + elif attn_mask_nodes_6 is not None: + # The mask has already been reshaped to (B,N,S,T) + add_qk_str = attn_mask_nodes_6[0].output[0] else: logger.debug("fuse_rotary_attention: failed to match attention mask nodes") return diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index de17f195c99cc..50703b9c17e03 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -1,6 +1,6 @@ import logging from collections import OrderedDict -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, Union import numpy import torch @@ -229,7 +229,7 @@ def __del__(self): del self.io_binding del self.ort_session - def allocate_buffers(self, shape_dict: Dict[str, tuple]): + def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]): """Allocate tensors for I/O Binding""" if self.enable_cuda_graph: for name, shape in shape_dict.items(): diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 6057b46667fe6..9619e6cb52a91 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -1,5 +1,18 @@ # LLaMA-2 +## Prerequisites + +Please note the package versions needed for using LLaMA-2 in the `requirements.txt` file that fits your scenario. +- `requirements-cpu.txt` + - For running LLaMA-2 on CPU +- `requirements-cuda.txt` + - For running LLaMA-2 on CUDA + - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. +- `requirements-quant.txt` + - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor) +- `requirements.txt` + - Package versions needed in each of the above files + ## Exporting LLaMA-2 There are several ways to export LLaMA-2 models (using LLaMA-2 7B as an example). @@ -40,7 +53,7 @@ Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onn ### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum) -Note that this will produce two ONNX models whereas the above two options produce one ONNX model. +Note that this may produce two ONNX models with older Optimum versions. The above two options produce one ONNX model and installing Optimum from source will now produce one ONNX model. First, log into the Hugging Face CLI in your terminal: @@ -81,7 +94,7 @@ Export for FP32 CUDA $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cuda +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda ``` Export for FP32 CPU @@ -90,7 +103,7 @@ Export for FP32 CPU $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cpu +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu ``` Export for FP16 CUDA @@ -105,10 +118,10 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama Export for INT8 CPU (SmoothQuant) ``` # From source: -$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged ``` Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes about ~30-45 min for each of the exported ONNX models. @@ -128,7 +141,7 @@ Export for INT4 CUDA $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cuda +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda ``` Export for INT4 CPU @@ -137,7 +150,7 @@ Export for INT4 CPU $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cpu +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu ``` ## Benchmark LLaMA-2 @@ -183,20 +196,7 @@ python3 -m models.llama.benchmark \ --auth ``` -4. Optimum + ONNX Runtime, FP16, export via convert_to_onnx -``` -python3 -m models.llama.benchmark \ - --benchmark-type hf-ort \ - --hf-ort-dir-path ./llama2-7b-fp16/ \ - --model-name meta-llama/Llama-2-7b-hf \ - --precision fp16 \ - --batch-sizes "1 2" \ - --sequence-lengths "8 16" \ - --device cuda \ - --auth -``` - -5. ONNX Runtime, FP32, Microsoft custom export +4. ONNX Runtime, FP32, Microsoft custom export ``` python3 -m models.llama.benchmark \ --benchmark-type ort-msft \ @@ -208,7 +208,7 @@ python3 -m models.llama.benchmark \ --device cpu ``` -6. ONNX Runtime, FP16, Microsoft custom export +5. ONNX Runtime, FP16, Microsoft custom export ``` python3 -m models.llama.benchmark \ --benchmark-type ort-msft \ @@ -220,7 +220,7 @@ python3 -m models.llama.benchmark \ --device cuda ``` -7. ONNX Runtime, FP32, convert_to_onnx +6. ONNX Runtime, FP32, convert_to_onnx ``` python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ @@ -232,7 +232,7 @@ python3 -m models.llama.benchmark \ --device cpu ``` -8. ONNX Runtime, FP16, convert_to_onnx +7. ONNX Runtime, FP16, convert_to_onnx ``` python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index 976de2abc7c57..245ff3dfe7f9d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -11,9 +11,8 @@ import onnx import psutil import torch -from benchmark_helper import setup_logger from llama_inputs import ( - convert_inputs_for_ort, + add_io_bindings, get_merged_sample_with_past_kv_inputs, get_msft_sample_inputs, get_sample_inputs, @@ -25,7 +24,7 @@ from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer import onnxruntime as ort -from onnxruntime.transformers.benchmark_helper import measure_memory +from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger logger = logging.getLogger(__name__) @@ -48,9 +47,19 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): init_inputs, iter_inputs = None, None # For past_present_share_buffer: - # Set max_seq_len to 2048 for Hugging Face model since that is the default value - # Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported - max_seq_len = 2048 + # Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2) + # Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value + # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_seq_len = ( + 2048 + if args.benchmark_type == "ort-msft" + else 16384 + if "codellama" in temp_name + else 4096 + if "llama2" in temp_name + else 2048 + ) if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: init_inputs = get_sample_inputs( @@ -95,7 +104,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=args.sequence_length, past_seq_len=0, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="pt", return_dict=True, ) iter_inputs = get_merged_sample_with_past_kv_inputs( @@ -104,7 +115,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=1, past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="pt", return_dict=True, ) @@ -116,7 +129,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=args.sequence_length, past_seq_len=0, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="ort", return_dict=True, ) iter_inputs = get_merged_sample_with_past_kv_inputs( @@ -125,26 +140,10 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=1, past_seq_len=args.sequence_length, - use_fp16=args.use_fp16, - return_dict=True, - ) - init_inputs = convert_inputs_for_ort( - init_inputs, - use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=0, max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, - ) - iter_inputs = convert_inputs_for_ort( - iter_inputs, use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=args.sequence_length, - max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, + engine="ort", + return_dict=True, ) elif args.benchmark_type == "ort-msft": @@ -156,6 +155,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, past_seq_len=0, seq_len=args.sequence_length, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, split_kv=split_kv, ) @@ -164,26 +164,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, past_seq_len=args.sequence_length, seq_len=1, - use_fp16=args.use_fp16, - split_kv=split_kv, - ) - init_inputs = convert_inputs_for_ort( - init_inputs, - use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=0, max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, - ) - iter_inputs = convert_inputs_for_ort( - iter_inputs, use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=args.sequence_length, - max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, + split_kv=split_kv, ) else: @@ -286,31 +269,50 @@ def time_fn(args, fn, inputs): outputs = fn(inputs) logger.info(outputs) + input_sync = ( # noqa: E731 + lambda *kwargs: args.io_binding.synchronize_inputs() + if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize + else lambda *kwargs: torch.cuda.synchronize() + if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize + else lambda *kwargs: None # no-op function + ) + + output_sync = ( # noqa: E731 + lambda *kwargs: args.io_binding.synchronize_outputs() + if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize + else lambda *kwargs: torch.cuda.synchronize() + if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize + else lambda *kwargs: None # no-op function + ) + for _ in warmup_range: + input_sync() fn(inputs) + output_sync() # Benchmark - if args.device != "cpu": - torch.cuda.synchronize() - start_time = time.time() - + total_time = 0 bench_range = ( range(args.num_runs) if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.num_runs, file=sys.stdout, desc="Benchmark") ) for _ in bench_range: + input_sync() + start_time = time.time() + fn(inputs) - if args.device != "cpu": - torch.cuda.synchronize() - end_time = time.time() + output_sync() + end_time = time.time() + + total_time += end_time - start_time # Newline print after trange in order to print metrics on new lines without progress bar on same line if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}: logger.info("") - latency = (end_time - start_time) / args.num_runs + latency = total_time / args.num_runs throughput = args.batch_size / latency logger.info(f"Batch Size: {args.batch_size}") @@ -430,7 +432,7 @@ def get_logits(inputs): def run_ort_inference(args, init_inputs, iter_inputs, model): - def prepare_ort_inputs(inputs): + def prepare_ort_inputs(inputs, kv_cache_ortvalues): # Check that all model inputs will be provided model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs())) user_inputs = set(inputs.keys()) @@ -448,28 +450,13 @@ def prepare_ort_inputs(inputs): # Add IO bindings for non-CPU execution providers if args.device != "cpu": - io_binding = model.io_binding() - - for k, v in inputs.items(): - if args.past_present_share_buffer: - # Bind all OrtValue inputs to device - io_binding.bind_ortvalue_input(k, v) - else: - io_binding.bind_cpu_input(k, v) - - for output in model.get_outputs(): - name = output.name - if args.past_present_share_buffer and ("out" in name or "present" in name): - # Bind present KV cache outputs to OrtValue with buffer sharing - io_binding.bind_ortvalue_output( - name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] - ) - else: - io_binding.bind_output(name, device_type=args.device, device_id=args.device_id) - - return io_binding + io_binding, kv_cache_ortvalues = add_io_bindings( + model, inputs, args.device, int(args.device_id), kv_cache_ortvalues + ) + setattr(args, "io_binding", io_binding) # noqa: B010 + return io_binding, kv_cache_ortvalues - return inputs + return inputs, kv_cache_ortvalues def with_io_binding(io_binding): # Inference pass with IO binding @@ -481,9 +468,10 @@ def without_io_binding(inputs): return outputs generate_fn = with_io_binding if args.device != "cpu" else without_io_binding + kv_cache_ortvalues = {} if args.profile: - ort_init_inputs = prepare_ort_inputs(init_inputs) + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt") # Turn profiling off to stop appending to log file @@ -493,7 +481,7 @@ def without_io_binding(inputs): # Re-initialize model for new log file instead of appending to old log file model = get_model(args) - ort_iter_inputs = prepare_ort_inputs(iter_inputs) + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token") # Turn profiling off to stop appending to log @@ -504,12 +492,12 @@ def without_io_binding(inputs): # ORT evaluations logger.info("\nEvaluating `model(inputs)` step to get past_key_values") - ort_init_inputs = prepare_ort_inputs(init_inputs) + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_init_inputs) measure_fn(args, generate_fn, ort_init_inputs) logger.info("\nEvaluating `model(inputs)` step with past_key_values") - ort_iter_inputs = prepare_ort_inputs(iter_inputs) + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_iter_inputs) measure_fn(args, generate_fn, ort_iter_inputs) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 61d71bc38f4e9..3f05be53c6729 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -716,6 +716,7 @@ def main(): run_torchscript_separate_export(args, l_config, llama) else: run_torchscript_merged_export(args, l_config, llama) + del llama # Delete LLaMA model from memory since it will be loaded again during parity check # Set model paths to store FP32 optimized model decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx") @@ -811,13 +812,13 @@ def main(): logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") remove_existing_model(fp_path) - del llama # Delete LLaMA model from memory since it will be loaded again during parity check logger.info("Verifying parity on all ONNX models created") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models args.precision = ( "fp32" - if args.precision in {"int8", "fp32"} or (args.precision == Precision.INT4 and args.execution_provider == "cpu") + if args.precision in {Precision.INT8, Precision.FLOAT32} + or (args.precision == Precision.INT4 and args.execution_provider == "cpu") else "fp16" ) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 2652e9f0ca64e..f7a1b05249abf 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -4,7 +4,7 @@ import torch from transformers import LlamaConfig -from onnxruntime import OrtValue +from onnxruntime import InferenceSession, OrtValue # Get position_ids from attention_mask @@ -12,22 +12,36 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if use_past_kv: + # Shape: (batch_size, 1) position_ids = position_ids[:, -1].unsqueeze(-1) + + # Shape: (batch_size, sequence_length) return position_ids # Inputs for first pass to get initial past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, sequence_length) +# position_ids: (batch_size, sequence_length) def get_sample_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, return_dict: bool = False + config: LlamaConfig, + device: torch.device, + batch_size: int, + seq_len: int, + engine: str = "pt", + return_dict: bool = False, ): - input_ids = torch.randint( - low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 - ) - attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64) - # position_ids is of shape (batch_size, seq_len) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64) position_ids = get_position_ids(attention_mask, use_past_kv=False) + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + if not return_dict: + # For export return (input_ids, attention_mask, position_ids) inputs = { @@ -39,85 +53,192 @@ def get_sample_inputs( # Inputs for subsequent passes with past_key_values +# input_ids: (batch_size, 1) +# attention_mask: (batch_size, past_sequence_length + 1) +# position_ids: (batch_size, 1) +# past_key: (batch_size, num_heads, past_sequence_length, head_size) +# past_value: (batch_size, num_heads, past_sequence_length, head_size) def get_sample_with_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool = False, + engine: str = "pt", return_dict: bool = False, ): - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64) - attention_mask = torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64) # position_ids is of shape (batch_size, 1) position_ids = get_position_ids(attention_mask, use_past_kv=True) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) if not return_dict: + # For export + assert isinstance(past_kv, list) return (input_ids, attention_mask, position_ids, past_kv) inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, - "past_key_values": past_kv, } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + return inputs # Inputs for all passes with past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, past_sequence_length + sequence_length) +# position_ids: (batch_size, sequence_length) +# past_key: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length +# past_value: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length def get_merged_sample_with_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, past_seq_len: int, + max_seq_len: int, use_fp16: bool = False, + engine: str = "pt", return_dict: bool = False, ): - input_ids = torch.randint( - low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 - ) - attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64) # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) if not return_dict: + # For export + assert isinstance(past_kv, list) return (input_ids, attention_mask, position_ids, past_kv) inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, - "past_key_values": past_kv, } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + + if use_fp16: # If model has GQA + del inputs["attention_mask"] + inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64) + inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) + + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + return inputs -# Create past_key_values -def get_sample_past_kv_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool +# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx +def get_msft_sample_inputs( + config: LlamaConfig, + batch_size: int, + past_seq_len: int, + seq_len: int, + max_seq_len: int, + use_fp16: bool, + split_kv: bool, ): + np_dtype = np.float16 if use_fp16 else np.float32 + head_size = config.hidden_size // config.num_attention_heads + + if not split_kv: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), + "k_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "v_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "pos": np.array(past_seq_len, dtype=np.int64), + } + else: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( + np.int32 + ), + "pos": np.array(past_seq_len, dtype=np.int64), + } + for i in range(config.num_hidden_layers): + ort_inputs.update( + { + f"k_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + f"v_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + } + ) + + if use_fp16: # If model has GQA + del ort_inputs["attn_mask"] + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) + + return ort_inputs + + +# Create past_key_values +# Each is of shape (batch_size, num_heads, past_sequence_length, head_size) +def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool): num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ ( - torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), - torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), ) for _ in range(config.num_hidden_layers) ] return past_kv -# Convert list of past_kv to dict of past_key and past_value -def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool): +# Convert list of past_key_values to dict of past_key and past_value +def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]): past_kv = {} - np_dtype = np.float16 if use_fp16 else np.float32 for i, (past_k, past_v) in enumerate(past_key_values): - past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy().astype(np_dtype) - past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy().astype(np_dtype) + past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy() + past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy() return past_kv @@ -136,7 +257,7 @@ def convert_inputs_for_ort( if isinstance(v, np.ndarray): ort_inputs[k] = v elif k == "past_key_values": - ort_inputs.update(flatten_past_kv_inputs(v, use_fp16)) + ort_inputs.update(flatten_past_kv_inputs(v)) elif k == "attention_mask" and use_fp16 and use_buffer_share: # Skip because FP16 model has GroupQueryAttention, uses buffer sharing, # and GQA supports a causal mask by default @@ -146,59 +267,55 @@ def convert_inputs_for_ort( else: ort_inputs[k] = v.detach().cpu().numpy() - # Enable past-present-share-buffer by using device memory directly + # Reshape kv caches if using past-present-share-buffer if use_buffer_share and device != "" and device != "cpu" and device_id > -1: - for k, v in ort_inputs.items(): - new_v = v - # Allocate new buffers with max_sequence_length for GQA - if "cache" in k or "past_key_values" in k: - # Copy v (BxSxPxH) into new_v (BxSxMxH) - batch_size, num_heads, _, head_size = v.shape - new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) - new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v - ort_inputs[k] = OrtValue.ortvalue_from_numpy(new_v, device_type=device, device_id=device_id) + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) return ort_inputs -# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx -def get_msft_sample_inputs( - config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool -): - np_dtype = np.float16 if use_fp16 else np.float32 - head_size = config.hidden_size // config.num_attention_heads - max_seq_len = 2048 +def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int): + for k, v in ort_inputs.items(): + # Allocate new buffers with max_sequence_length for GQA + if "cache" in k or "past_key_values" in k: + # Copy v (BxSxPxH) into new_v (BxSxMxH) + batch_size, num_heads, _, head_size = v.shape + new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) + new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v + ort_inputs[k] = new_v + return ort_inputs - if not split_kv: - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), - "k_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "v_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "pos": np.array(past_seq_len, dtype=np.int64), - } - else: - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( - np.int32 - ), - "pos": np.array(past_seq_len, dtype=np.int64), - } - for i in range(config.num_hidden_layers): - ort_inputs.update( - { - f"k_{i}_cache": np.random.rand( - batch_size, config.num_attention_heads, past_seq_len, head_size - ).astype(np_dtype), - f"v_{i}_cache": np.random.rand( - batch_size, config.num_attention_heads, past_seq_len, head_size - ).astype(np_dtype), - } - ) - return ort_inputs +# Add IO bindings for execution providers +def add_io_bindings(model: InferenceSession, ort_inputs: dict, device: str, device_id: int, kv_cache_ortvalues: dict): + use_fp16 = False + io_binding = model.io_binding() + + for k, v in ort_inputs.items(): + # Detect if model is in FP16 + if v.dtype == np.float16: + use_fp16 = True + + # Bind OrtValue inputs to device + if use_fp16 and ("cache" in k or "past_key_values" in k): + if k not in kv_cache_ortvalues: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + kv_cache_ortvalues[k] = v_device + else: + kv_cache_ortvalues[k].update_inplace(v) + io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k]) + else: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + + for output in model.get_outputs(): + name = output.name + if use_fp16 and ("out" in name or "present" in name): + # Bind present KV cache outputs to past KV cache inputs in order to buffer share + input_name = name.replace("out", "cache").replace("present", "past_key_values") + io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name]) + else: + io_binding.bind_output(name, device_type=device, device_id=device_id) + + return io_binding, kv_cache_ortvalues diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 6bfcb9b4f290d..c1c5d3c412f2a 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -8,6 +8,7 @@ import torch from benchmark_helper import setup_logger from llama_inputs import ( + add_io_bindings, convert_inputs_for_ort, get_merged_sample_with_past_kv_inputs, get_sample_inputs, @@ -22,22 +23,24 @@ def get_sequence_lengths(args: argparse.Namespace): past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8) - max_sequence_length = 2048 + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 return past_sequence_length, curr_sequence_length, max_sequence_length def get_inputs(args: argparse.Namespace, config: LlamaConfig): # Dummy values for parity batch_size = 2 - past_sequence_length, sequence_length, _ = get_sequence_lengths(args) + past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args) if args.merged: inputs = get_merged_sample_with_past_kv_inputs( config, args.device, batch_size, - sequence_length, - past_sequence_length, + seq_len=sequence_length, + past_seq_len=past_sequence_length, + max_seq_len=max_sequence_length, use_fp16=args.use_fp16, return_dict=True, ) @@ -51,31 +54,7 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): return inputs -def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict): - # Add IO bindings for non-CPU execution providers - io_binding = model.io_binding() - - for k, v in inputs.items(): - if args.use_fp16: - # Bind all OrtValue inputs to device - io_binding.bind_ortvalue_input(k, v) - else: - io_binding.bind_cpu_input(k, v) - - for output in model.get_outputs(): - name = output.name - if args.use_fp16 and ("out" in name or "present" in name): - # Bind present KV cache outputs to OrtValue with buffer sharing - io_binding.bind_ortvalue_output( - name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] - ) - else: - io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id)) - - return io_binding - - -def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): +def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM, kv_cache_ortvalues: dict): inputs = get_inputs(args, config) # Run inference with PyTorch @@ -111,12 +90,14 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama # Add IO bindings for non-CPU execution providers if args.execution_provider != "cpu": - io_binding = add_io_bindings(args, ort_model, inputs) + io_binding, kv_cache_ortvalues = add_io_bindings( + ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues + ) - torch.cuda.synchronize() + io_binding.synchronize_inputs() start_time = time.time() ort_model.run_with_iobinding(io_binding) - torch.cuda.synchronize() + io_binding.synchronize_outputs() end_time = time.time() ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits @@ -131,17 +112,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama logger.info(f"ONNX Runtime took {end_time - start_time} s") # Compare PyTorch and ONNX Runtime accuracy - tol = ( - 2e1 - if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path - else 1e-3 - if args.precision == "fp32" - else 5e-1 - ) + tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1 parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol) logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}") if not parity: logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}") + return kv_cache_ortvalues def get_args(argv: List[str]): @@ -250,16 +226,17 @@ def main(argv: List[str] = []): # noqa: B006 use_cache=True, ).to(args.device) + kv_cache_ortvalues = {} if not args.merged: - verify_parity(args, config, llama) + verify_parity(args, config, llama, kv_cache_ortvalues) else: # Verify prompt generation in merged model (decoder_model.onnx) args.use_past_kv = False - verify_parity(args, config, llama) + kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues) # Verify token generation in merged model (decoder_with_past_model.onnx) args.use_past_kv = True - verify_parity(args, config, llama) + verify_parity(args, config, llama, kv_cache_ortvalues) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt index e06c3ada834b0..3d707fa13e3c8 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt @@ -1,2 +1,2 @@ -r requirements.txt -onnxruntime>=1.17.0 \ No newline at end of file +onnxruntime>=1.16.2 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt index 773680937bd21..b634bcc50f6e4 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt @@ -1,4 +1,4 @@ -r requirements.txt # Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system. # Instructions can be found here: https://pytorch.org/get-started/locally/ -onnxruntime-gpu>=1.17.0 \ No newline at end of file +onnxruntime-gpu>=1.16.2 \ No newline at end of file diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py index 9dee6564509d5..8bf7cbf80eb37 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py @@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension(): from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor _C.register_aten_op_executor( - str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address()) + str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address()) ) diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc index 182f2368f5b47..903a394a06ef3 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc @@ -154,11 +154,32 @@ class ATenOperatorCache { std::unordered_map, ATenOperator, PairHash> ops_; }; -// Backend uses this function to check if an argument is CPU input (non-tensor argument) or not. -bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index) { - const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); - TORCH_INTERNAL_ASSERT(index < aten_op.argument_size); - return aten_op.elem_kinds[index] == c10::TypeKind::TensorType; +const std::unordered_map> kCpuTensorInputsMap = { + {"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}}; + +const std::unordered_map> kCpuTensorOutputsMap = { + {"_efficient_attention_forward", {2, 3}}}; + +// Backend uses this function to check if an argument is CPU input or not. +bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) { + if (is_input) { + // If the argument is non-tensor type, it's CPU argument. + const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); + TORCH_INTERNAL_ASSERT(index < aten_op.argument_size); + if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) { + return true; + } + } + + std::string full_name = std::string(op_name); + std::string overload_name_str = std::string(overload_name); + if (overload_name_str != "") { + full_name += ("." + overload_name_str); + } + + const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap; + return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() && + cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end(); } void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size, @@ -196,14 +217,15 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t size_t output_index = 0; for (const auto& ret : torch::jit::pop(stack, output_size)) { const auto& tensor = ret.toTensor(); - dlpack_outputs[output_index++] = at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()); + dlpack_outputs[output_index++] = + tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr; } } -size_t is_tensor_argument_address() { return reinterpret_cast(&IsTensorArgument); } +size_t is_cpu_argument_address() { return reinterpret_cast(&IsCpuArgument); } size_t execute_aten_operator_address() { return reinterpret_cast(&ExecuteATenOperator); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check."); + m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check."); m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor"); } diff --git a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py index 7d5716b85db30..329fba5aa670a 100644 --- a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py @@ -5,7 +5,7 @@ from onnxruntime.capi import _pybind_state as _C -from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address +from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address def run_once_aten_op_executor(f): @@ -30,7 +30,7 @@ def aten_op_executor_wrapper(*args, **kwargs): @run_once_aten_op_executor def load_aten_op_executor_cpp_extension(): - _C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address())) + _C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address())) def init_aten_op_executor(): diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc new file mode 100644 index 0000000000000..fefd5722054de --- /dev/null +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/provider_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace std; + +namespace onnxruntime { +namespace test { + +TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { + constexpr int64_t B = 2; + constexpr int64_t C = 16; + constexpr int64_t H = 2; + constexpr int64_t W = 2; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + -0.768555f, 1.575195f, -0.698242f, 1.587891f, 0.371826f, -0.280029f, -1.328125f, 0.127197f, + -0.197144f, 0.982422f, -0.671387f, -1.925781f, 1.800781f, -0.020218f, -0.782227f, 1.291992f, + -0.935059f, 1.782227f, -0.674316f, -1.943359f, -0.218994f, 0.054138f, -1.539062f, -0.546387f, + -2.160156f, 1.195312f, 1.653320f, -0.674316f, 0.224731f, -0.093262f, 1.160156f, -0.389404f, + 1.748047f, 0.766113f, 0.234375f, 0.011177f, -0.055847f, -0.930664f, -0.490234f, -0.655762f, + -0.382568f, -0.554688f, 0.910645f, -0.227295f, 1.687500f, 0.028397f, -0.241699f, -0.480957f, + -0.355713f, -2.095703f, -0.443359f, -0.126221f, -0.815918f, 0.792969f, -0.450439f, -0.952148f, + -1.174805f, 0.242798f, 0.138550f, -0.237061f, -0.994141f, 0.346436f, 0.147705f, 0.125854f, + -0.517090f, 0.253906f, 0.400146f, -0.540039f, -0.788574f, 0.146606f, -0.409668f, 0.281982f, + 1.444336f, 0.044434f, -0.366699f, 2.250000f, -0.453613f, -0.652344f, 1.828125f, -0.244751f, + 0.307129f, -0.051361f, 0.106384f, 0.844727f, 1.648438f, -0.904785f, -0.353760f, 0.510742f, + 0.074829f, -0.311279f, 0.274902f, 1.594727f, 1.367188f, 0.098755f, 0.043304f, -0.207397f, + 0.068298f, -0.601074f, 0.083008f, 0.264893f, -0.659180f, -0.216797f, -0.086548f, -0.683594f, + -0.964844f, -2.591797f, -0.817383f, -0.461914f, -1.840820f, -0.712402f, -0.052094f, -0.583008f, + 1.114258f, 0.190308f, 1.087891f, 0.005146f, 1.041992f, 1.363281f, -0.273682f, -0.465576f, + -0.027618f, 1.345703f, 0.789551f, -0.015991f, 0.401611f, 0.726562f, 0.598633f, 0.133667f}; + + std::vector gamma_data = { + 0.241255f, 0.556660f, -0.835532f, 0.564596f, -1.338308f, -0.278924f, 0.357326f, -1.745484f, + 0.277184f, 0.101415f, -0.018637f, -0.526188f, -0.011698f, -2.349411f, 0.206578f, 0.357679f}; + + std::vector beta_data = { + -1.194839f, 0.209146f, -0.677225f, -0.547338f, 1.275685f, -1.099577f, 0.470916f, 0.293907f, + -1.094209f, 2.350204f, -1.633769f, 0.248753f, -0.180166f, 0.365134f, -0.555731f, 1.843083f}; + + std::vector skip_data_nhwc = { + 0.892578f, -0.471924f, -0.423096f, 1.277344f, 0.257080f, -1.366211f, 1.552734f, 0.441406f, + -0.033142f, -0.059418f, 1.536133f, -0.225464f, 1.472656f, 0.591309f, -0.386230f, -2.197266f, + 0.089600f, -0.256592f, -1.873047f, 0.916992f, 0.392090f, 0.015526f, -0.949219f, 0.566895f, + -0.220459f, 1.262695f, -0.437744f, -2.283203f, -0.264893f, -0.660156f, 2.353516f, 1.992188f, + 0.865723f, -0.854004f, -1.014648f, 0.899414f, -1.041016f, 1.378906f, -0.075073f, -2.541016f, + -0.883789f, -0.428711f, 0.981934f, -0.072754f, 2.214844f, 0.658203f, 0.170166f, -1.727539f, + -0.672363f, -1.373047f, 0.318115f, 0.422363f, 0.260742f, -0.547852f, 0.545898f, -0.155762f, + 0.679688f, 2.861328f, -0.300781f, -0.504883f, 1.548828f, 0.353760f, -0.387695f, -1.595703f, + -0.170166f, -0.002897f, 0.273193f, -0.383545f, -1.082031f, -0.894043f, -1.048828f, -0.044708f, + 0.049286f, 0.220215f, 0.272705f, -0.853027f, -0.489258f, 0.513672f, 0.977051f, 0.310547f, + -0.577148f, -0.479004f, 0.838867f, 0.872559f, -0.510254f, 0.101807f, -0.299805f, -1.179688f, + -1.555664f, 0.668457f, 0.939453f, 0.118103f, -0.376709f, 0.735352f, -0.214233f, -1.987305f, + -0.931152f, 1.268555f, 1.427734f, -0.757812f, -1.324219f, 0.375488f, 1.364258f, -1.708008f, + 0.976562f, -0.037659f, -1.779297f, -0.196655f, 1.636719f, 0.690430f, 0.941895f, -1.882812f, + 0.431641f, 0.203857f, 1.306641f, -0.126343f, 1.408203f, 1.188477f, 0.432861f, -2.296875f, + -0.475342f, 1.517578f, -0.824219f, 1.288086f, -0.028244f, 1.918945f, 0.352295f, 0.693359f}; + + std::vector bias_data = { + -0.537598f, 0.500488f, -0.252441f, -0.460693f, -1.640625f, -1.298828f, 0.331787f, -1.588867f, + 1.000977f, 1.458984f, 0.702637f, 0.147827f, 1.143555f, 0.533691f, -0.072510f, 0.511230f}; + + std::vector norm_data_nhwc = { + -1.213867f, 0.856445f, -0.119141f, 0.386475f, 0.714355f, -0.804688f, + 1.048828f, -0.426270f, -1.091797f, 2.435547f, -1.641602f, 0.989746f, + -0.200928f, 0.267334f, -0.800781f, 1.577148f, -1.357422f, 1.000977f, + 0.613281f, -0.963867f, 1.179688f, -1.169922f, 0.308350f, 0.304199f, + -1.396484f, 2.513672f, -1.644531f, 1.206055f, -0.180664f, 1.896484f, + -0.294678f, 2.046875f, -0.844238f, 0.448486f, -0.294189f, -0.291504f, + 2.480469f, -1.250977f, 0.833008f, 4.593750f, -1.238281f, 2.335938f, + -1.651367f, 0.491943f, -0.204834f, 0.125610f, -0.682129f, 1.333984f, + -1.384766f, -0.708008f, -0.630859f, -0.504883f, 1.924805f, -1.208008f, + 1.013672f, 1.809570f, -1.128906f, 2.546875f, -1.631836f, 0.610840f, + -0.184326f, 0.110046f, -0.700195f, 1.471680f, -1.511719f, 0.492188f, + -0.847168f, -1.373047f, 2.837891f, -0.998047f, 0.521484f, 0.262207f, + -0.810547f, 2.400391f, -1.628906f, 0.049896f, -0.174927f, 1.076172f, + -0.252197f, 1.784180f, -1.418945f, 0.090820f, -1.056641f, 0.002945f, + 0.627441f, -0.989746f, 0.679199f, 1.130859f, -1.371094f, 2.408203f, + -1.645508f, -0.062988f, -0.192017f, -0.655762f, -0.718262f, 1.170898f, + -1.550781f, 0.706055f, -1.492188f, -1.148438f, 2.921875f, -1.136719f, + 1.058594f, 2.781250f, -1.089844f, 2.201172f, -1.597656f, 0.785645f, + -0.181396f, 0.868164f, -0.552246f, 1.097656f, -1.015625f, 0.565430f, + -2.173828f, -0.955078f, -0.336426f, -1.503906f, 0.838867f, 3.136719f, + -1.186523f, 2.580078f, -1.629883f, 0.094604f, -0.186523f, -3.884766f, + -0.542480f, 1.990234f}; + + std::vector add_out_data_nhwc = { + -0.414062f, 1.604492f, -1.374023f, 2.404297f, -1.011719f, -2.945312f, 0.556641f, -1.020508f, + 0.770508f, 2.382812f, 1.567383f, -2.003906f, 4.417969f, 1.105469f, -1.240234f, -0.394531f, + -1.382812f, 2.027344f, -2.800781f, -1.487305f, -1.466797f, -1.229492f, -2.156250f, -1.568359f, + -1.379883f, 3.917969f, 1.917969f, -2.808594f, 1.103516f, -0.219727f, 3.441406f, 2.113281f, + 2.076172f, 0.412598f, -1.033203f, 0.449951f, -2.738281f, -0.851562f, -0.233521f, -4.785156f, + -0.265625f, 0.475586f, 2.595703f, -0.152222f, 5.046875f, 1.220703f, -0.144043f, -1.697266f, + -1.566406f, -2.968750f, -0.377686f, -0.164551f, -2.195312f, -1.053711f, 0.427246f, -2.697266f, + 0.505859f, 4.562500f, 0.540527f, -0.594238f, 1.698242f, 1.233398f, -0.312500f, -0.958496f, + -1.224609f, 0.751465f, 0.420898f, -1.384766f, -3.511719f, -2.046875f, -1.126953f, -1.351562f, + 2.494141f, 1.724609f, 0.608398f, 1.544922f, 0.200684f, 0.395020f, 2.732422f, 0.577148f, + -0.807617f, -0.029785f, 0.692871f, 1.256836f, -0.502441f, -2.101562f, -0.321777f, -2.257812f, + -0.479492f, 1.816406f, 1.916992f, 1.860352f, 2.134766f, 1.367188f, -0.243408f, -1.683594f, + -1.400391f, 1.167969f, 1.257812f, -0.953613f, -3.625000f, -1.140625f, 1.609375f, -3.980469f, + 1.012695f, -1.170898f, -1.894531f, -0.510742f, 0.939453f, 0.511719f, 0.817383f, -1.955078f, + 1.007812f, 0.894531f, 2.142578f, -0.582031f, 0.809570f, 1.252930f, 0.490967f, -4.351562f, + 0.497803f, 4.320312f, 0.667969f, 1.419922f, 1.516602f, 3.179688f, 0.878906f, 1.337891f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array channels_last_values = {-1, 1}; + + for (const int channels_last : channels_last_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 4); + test.AddAttribute("activation", 0); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + test.AddInput("skip", dims_nhwc, ToFloat16(skip_data_nhwc)); + test.AddInput("bias", {C}, ToFloat16(bias_data)); + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } +} + +TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { + constexpr int64_t B = 1; + constexpr int64_t C = 64; + constexpr int64_t H = 1; + constexpr int64_t W = 1; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + 0.588867f, 0.896484f, -0.213623f, 0.803223f, 0.659180f, -0.216187f, 1.197266f, -0.486084f, + -0.718750f, 0.332031f, -0.364746f, -0.831543f, -0.031219f, -1.059570f, 0.161621f, 1.519531f, + 0.169312f, 1.048828f, 1.330078f, 0.450195f, -2.867188f, -1.456055f, 0.708496f, -1.120117f, + -1.208984f, -1.199219f, -1.505859f, -0.549316f, 0.505371f, 0.723145f, -0.359131f, -0.250977f, + -0.879883f, -0.305664f, 0.709473f, 0.815430f, 0.617676f, -0.638672f, 0.066772f, -2.330078f, + -1.316406f, 1.744141f, 1.122070f, -0.633789f, -1.802734f, -0.825684f, 0.622559f, -0.481689f, + -1.364258f, -0.536621f, -0.464111f, 0.247437f, -0.213989f, 0.384521f, 0.556641f, -0.303711f, + -0.160034f, 0.882324f, -0.212036f, -0.796387f, 0.153076f, -1.311523f, 2.212891f, 0.685059f}; + + std::vector gamma_data = { + 0.789682f, 0.869051f, -0.010169f, -0.021685f, 0.506611f, 1.267444f, -0.312695f, 0.877844f, + 0.598637f, 0.598314f, -1.721544f, -0.593328f, 0.986705f, -0.419391f, -0.852584f, -0.572351f, + 0.912797f, -0.586863f, 0.477761f, -0.484418f, -0.193835f, 0.347757f, 0.327637f, -1.100304f, + 1.233108f, -0.272569f, -0.688656f, 0.687245f, 0.398386f, 0.888089f, -0.792587f, -0.769029f, + -0.427778f, 0.100768f, -2.187060f, 1.279301f, 1.109054f, 0.375992f, 1.514775f, 1.271436f, + 0.822896f, -0.476750f, 0.475507f, -1.011297f, 1.177197f, 1.586540f, -1.059944f, -0.145351f, + 0.841555f, -2.014113f, -0.230498f, 0.302128f, -0.180508f, 0.980534f, -0.126871f, 0.203151f, + -0.754841f, 0.420570f, -1.085798f, 1.335042f, -0.674930f, 2.453507f, 2.139259f, 1.087436f}; + + std::vector beta_data = { + -0.064518f, -0.262683f, 0.827528f, -0.960938f, 1.062519f, 2.417941f, 0.212789f, -1.638430f, + 1.875453f, -0.883058f, -0.006704f, 0.424894f, -0.869972f, 0.727008f, 0.879303f, -3.024141f, + -2.610873f, 1.269641f, 0.883006f, 0.804167f, -1.510324f, 2.258091f, -0.006750f, -1.553668f, + -1.659453f, 0.579603f, 0.652358f, 0.007077f, 0.099180f, 0.418658f, -0.273778f, -1.036199f, + -1.128691f, -0.296022f, -0.224056f, 1.476306f, 0.577624f, -0.372049f, -0.581659f, -1.841807f, + -0.361721f, 0.051160f, -0.749332f, -2.634807f, 0.562719f, -0.738667f, 0.024864f, -1.135937f, + -1.368144f, -1.458886f, -0.946683f, 1.953936f, -1.198661f, 0.166648f, 0.447206f, -0.458140f, + -0.553395f, 0.112900f, 0.255989f, -0.184551f, 1.254163f, -0.260479f, -1.232429f, 1.902575f}; + + std::vector skip_data = { + 0.952148f, 1.342773f, -0.172974f, -0.395264f, 1.119141f, 0.330566f, + 0.281494f, 0.472900f, -0.692871f, -0.634766f, 0.013504f, -1.866211f, + -0.428223f, 0.669922f, -0.323486f, 0.713867f, -0.350586f, 0.659180f, + -0.288574f, 0.324219f, -0.300781f, -0.789551f, -0.216431f, -0.221436f, + -0.086670f, 0.366211f, -0.643555f, -0.977051f, 0.001021f, 0.415527f, + -0.271729f, 0.836426f, 0.035370f, -0.806152f, 0.936035f, -0.021332f, + -1.095703f, 0.971680f, 1.648438f, 0.840820f, 0.837402f, 0.607910f, + -1.894531f, 0.666016f, -0.171143f, 1.625977f, -0.620117f, -0.039581f, + 1.702148f, -2.410156f, 1.565430f, -0.756348f, 1.446289f, 0.583496f, + -0.497559f, -0.271729f, -0.956055f, -1.642578f, 0.833496f, -1.136719f, + 1.248047f, -2.515625f, 0.080383f, 0.376221f}; + + std::vector norm_data_nhwc = { + 0.494873f, 1.017578f, 0.841797f, -0.949219f, 1.552734f, 1.333984f, 0.012703f, -2.511719f, + 1.424805f, -0.818359f, -0.128418f, 1.462891f, -0.882812f, 0.709961f, 0.693848f, -4.210938f, + -2.505859f, 0.513184f, 1.300781f, 0.460938f, -1.172852f, 1.851562f, 0.167969f, -0.885254f, + -2.535156f, 0.656738f, 1.683594f, -0.627441f, 0.478271f, 1.782227f, -0.196777f, -1.824219f, + -0.791016f, -0.398682f, -3.197266f, 2.275391f, 0.052704f, -0.286865f, 1.567383f, -3.552734f, + -0.646973f, -0.927734f, -1.032227f, -2.722656f, -1.337891f, 0.432129f, -0.040253f, -1.080078f, + -1.118164f, 3.123047f, -1.153320f, 1.843750f, -1.378906f, 0.941406f, 0.437256f, -0.542969f, + -0.218872f, 0.006115f, -0.265869f, -1.356445f, 0.649902f, -4.882812f, 1.696289f, 2.679688f}; + + std::vector add_out_data_nhwc = { + 1.541016f, 2.238281f, -0.386719f, 0.407959f, 1.778320f, 0.114380f, + 1.478516f, -0.013184f, -1.412109f, -0.302734f, -0.351318f, -2.697266f, + -0.459473f, -0.389648f, -0.161865f, 2.234375f, -0.181274f, 1.708008f, + 1.041016f, 0.774414f, -3.167969f, -2.246094f, 0.492188f, -1.341797f, + -1.295898f, -0.833008f, -2.148438f, -1.526367f, 0.506348f, 1.138672f, + -0.630859f, 0.585449f, -0.844727f, -1.111328f, 1.645508f, 0.793945f, + -0.478027f, 0.333008f, 1.714844f, -1.489258f, -0.479004f, 2.351562f, + -0.772461f, 0.032227f, -1.973633f, 0.800293f, 0.002441f, -0.521484f, + 0.337891f, -2.947266f, 1.101562f, -0.508789f, 1.232422f, 0.967773f, + 0.059082f, -0.575195f, -1.116211f, -0.760254f, 0.621582f, -1.933594f, + 1.401367f, -3.828125f, 2.292969f, 1.061523f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array has_add_out_values = {true, false}; + std::array skip_dims = {2, 4}; + + constexpr int channels_last = 1; + for (const int skip_dim : skip_dims) { + for (const bool has_add_out : has_add_out_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 8); + test.AddAttribute("activation", 0); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + if (skip_dim == 2) { + test.AddInput("skip", {B, C}, ToFloat16(skip_data)); + } else { + test.AddInput("skip", {B, 1, 1, C}, ToFloat16(skip_data)); + } + // no bias + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + + if (has_add_out) { + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + } + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index c38baee39216b..1804c09043c7b 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -4,6 +4,7 @@ #include "core/framework/allocator.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/graph/model.h" +#include "core/graph/node_attr_utils.h" #include "gtest/gtest.h" #include "test_utils.h" #include "test/test_environment.h" @@ -110,6 +111,70 @@ TEST(TransformerTest, InsertCastAllCPUTest) { } } +TEST(TransformerTest, CastRemovalDoesNotLowerPrecisionTest) { + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + TypeProto tensor_float_32; + tensor_float_32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + TypeProto tensor_float_64; + tensor_float_64.mutable_tensor_type()->set_elem_type(TensorProto_DataType_DOUBLE); + onnxruntime::NodeArg n1_def("N1", &tensor_float_64), + n2_def("N2", &tensor_float_32), + n3_def("N3", &tensor_float_64); + + NodeAttributes n1_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT))}}; + NodeAttributes n2_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE))}}; + + graph.AddNode("node1", "Cast", "F64 to F32 cast", ArgMap{&n1_def}, ArgMap{&n2_def}, &n1_attrs); + graph.AddNode("node2", "Cast", "F32 to F64 cast", ArgMap{&n2_def}, ArgMap{&n3_def}, &n2_attrs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + InsertCastTransformer cast_inserter("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + status = cast_inserter.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // When casting f64 -> f32 -> f64 we should not be optimising away the cast since there is a loss of precision. + EXPECT_EQ(graph.NumberOfNodes(), 2); +} + +TEST(TransformerTest, CastRemovalDoesNotRemoveSignednessTest) { + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + TypeProto tensor_uint32; + tensor_uint32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_UINT32); + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + onnxruntime::NodeArg n1_def("N1", &tensor_int32), + n2_def("N2", &tensor_uint32), + n3_def("N3", &tensor_int32); + + NodeAttributes n1_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_UINT32))}}; + NodeAttributes n2_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_INT32))}}; + + graph.AddNode("node1", "Cast", "I32 to UI32 cast", ArgMap{&n1_def}, ArgMap{&n2_def}, &n1_attrs); + graph.AddNode("node2", "Cast", "UI32 to I32 cast", ArgMap{&n2_def}, ArgMap{&n3_def}, &n2_attrs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + InsertCastTransformer cast_inserter("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + status = cast_inserter.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // When casting i32 -> ui32 -> i32 we should not be optimising away the cast since applying the casts produces a very different result. + EXPECT_EQ(graph.NumberOfNodes(), 2); +} + // test that when there are 3 Cast ops in a row we remove the correct ones TEST(TransformerTest, ThreeInARowRemoval) { auto model_uri = MODEL_FOLDER ORT_TSTR("triple-cast.onnx"); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp new file mode 100644 index 0000000000000..6f06e0f2eead8 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -0,0 +1,208 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_blockq4.cpp + +Abstract: + + Tests for MLAS blockwise int4 quantization and dequantization code. + +--*/ + +#ifndef ORT_MINIMAL_BUILD + +#include "test_util.h" +#include "mlas_q4.h" + +class MlasBlockwiseQdqTest : public MlasTestBase { + private: + MatrixGuardBuffer FpBuf; + MatrixGuardBuffer FpBuf2; + MatrixGuardBuffer InputElements; + MatrixGuardBuffer InputScales; + MatrixGuardBuffer InputOffsets; + MatrixGuardBuffer OutputElements; + MatrixGuardBuffer OutputScales; + MatrixGuardBuffer OutputOffsets; + + void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) { + float* dequant_buf = FpBuf.GetBuffer(rows * columns, true); + float* transposed = FpBuf2.GetBuffer(rows * columns, true); + + MLAS_THREADPOOL* threadpool_ptr = GetMlasThreadPool(); + + int meta_rows; + int meta_cols; + MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); + + int q_rows; + int q_cols; + MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); + + uint8_t* elements = InputElements.GetBuffer(q_rows * q_cols, true); + + int v = 7; + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += 2) { + int idx = c * q_rows + r / 2; + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + elements[idx] = (v1 << 4) | v0; + } + } + + float* scales = InputScales.GetBuffer(meta_rows * meta_cols); + uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); + if (zp) { + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += 2) { + int idx = c * ((meta_rows + 1) / 2) + r / 2; + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < meta_rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + zp[idx] = (v1 << 4) | v0; + } + } + } + + MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, columnwise, rows, columns, threadpool_ptr); + + MlasTranspose(dequant_buf, transposed, columns, rows); + + uint8_t* o_elements = OutputElements.GetBuffer(q_rows * q_cols, true); + float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); + uint8_t* o_zp = symmetric ? nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); + + MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, columnwise, rows, columns, columns, threadpool_ptr); + + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += 2) { + int idx = c * q_rows + r / 2; + ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) + << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (r + 1 < rows) { + ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) + << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + } + } + + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r++) { + int idx = c * meta_rows + r; + ASSERT_EQ(o_scales[idx], scales[idx]) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + } + + if (symmetric) return; + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += 2) { + int idx = c * ((meta_rows + 1) / 2) + r / 2; + ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (r + 1 < meta_rows) { + ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) + << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + } + } + } + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("BlockQ4"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + Test(20, 1, 32, true, false); + Test(20, 1, 32, true, true); + Test(1, 20, 32, false, false); + Test(1, 20, 32, false, true); + Test(52, 1, 32, true, false); + Test(52, 1, 32, true, true); + Test(1, 52, 32, false, false); + Test(1, 52, 32, false, true); + Test(20, 3, 32, true, false); + Test(20, 3, 32, true, true); + Test(3, 20, 32, false, false); + Test(3, 20, 32, false, true); + Test(52, 3, 32, true, false); + Test(52, 3, 32, true, true); + Test(3, 52, 32, false, false); + Test(3, 52, 32, false, true); + Test(52, 3, 64, true, false); + Test(52, 3, 64, true, true); + Test(3, 52, 64, false, false); + Test(3, 52, 64, false, true); + Test(32 * 9 + 17, 41, 32, true, false); + Test(32 * 9 + 17, 41, 32, true, true); + Test(41, 32 * 9 + 17, 32, false, false); + Test(41, 32 * 9 + 17, 32, false, true); + Test(32 * 9 + 17, 41, 64, true, false); + Test(32 * 9 + 17, 41, 64, true, true); + Test(41, 32 * 9 + 17, 64, false, false); + Test(41, 32 * 9 + 17, 64, false, true); + Test(32 * 15 + 17, 63, 128, true, false); + Test(32 * 15 + 17, 63, 128, true, true); + Test(63, 32 * 15 + 17, 128, false, false); + Test(63, 32 * 15 + 17, 128, false, true); + + Test(256, 256, 32, true, false); + Test(256, 256, 32, true, true); + Test(256, 256, 32, false, false); + Test(256, 256, 32, false, true); + } + + MlasBlockwiseQdqTest() = default; +}; + +template <> +MlasBlockwiseQdqTest* MlasTestFixture::mlas_tester(nullptr); + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); + +#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index de5431ca4a460..0526ccca5bb4e 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -761,6 +761,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ORT_TSTR("sce_none_weights_expanded")}; std::unordered_set> all_disabled_tests(std::begin(immutable_broken_tests), std::end(immutable_broken_tests)); + if (enable_cuda) { all_disabled_tests.insert(std::begin(cuda_flaky_tests), std::end(cuda_flaky_tests)); } diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index b1a04a00e89b1..6d075fec997b5 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -60,7 +60,7 @@ namespace perftest { "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" - "\t [OpenVINO only] [enable_vpu_fast_compile]: Optionally enabled to speeds up the model's compilation on VPU device targets.\n" + "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" "\t [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" "\t [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" @@ -72,7 +72,7 @@ namespace perftest { "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" "\t [Usage]: -e -i '| |'\n\n" - "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_vpu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" + "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" "\t [TensorRT only] [trt_min_subgraph_size]: Minimum size of TensorRT subgraphs.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index a7f0b7584a211..b7a111783fc94 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include "core/session/onnxruntime_session_options_config_keys.h" @@ -100,36 +101,28 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device const auto& api = Ort::GetApi(); OrtCUDAProviderOptionsV2* cuda_options; Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options)); - - const char* cudnn_conv_algo_search = "cudnn_conv_algo_search"; - const char* default_conv = "DEFAULT"; - const char* benchmarking = "EXHAUSTIVE"; - const char* heuristic = "HEURISTIC"; + std::vector option_keys, option_values; + // used to keep all option keys and value strings alive + std::list buffer; + buffer.emplace_back("cudnn_conv_algo_search"); + option_keys.push_back(buffer.back().c_str()); switch (performance_test_config.run_config.cudnn_conv_algo) { case 0: - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &cudnn_conv_algo_search, &benchmarking, 1)); + buffer.emplace_back("EXHAUSTIVE"); break; case 1: - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &cudnn_conv_algo_search, &heuristic, 1)); + buffer.emplace_back("HEURISTIC"); break; default: - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &cudnn_conv_algo_search, &default_conv, 1)); + buffer.emplace_back("DEFAULT"); break; } + option_values.push_back(buffer.back().c_str()); - const char* do_copy_in_default_stream = "do_copy_in_default_stream"; - if (performance_test_config.run_config.do_cuda_copy_in_separate_stream) { - const char* v = "1"; - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &do_copy_in_default_stream, &v, 1)); - } else { - const char* v = "0"; - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &do_copy_in_default_stream, &v, 1)); - } + buffer.emplace_back("do_copy_in_default_stream"); + option_keys.push_back(buffer.back().c_str()); + buffer.emplace_back(!performance_test_config.run_config.do_cuda_copy_in_separate_stream ? "1" : "0"); + option_values.push_back(buffer.back().c_str()); #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -148,51 +141,34 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device "[ERROR] [CUDA] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); } - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - auto key_p = key.c_str(); - auto value_p = value.c_str(); - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &key_p, &value_p, 1)); + buffer.emplace_back(token.substr(0, pos)); + option_keys.push_back(buffer.back().c_str()); + buffer.emplace_back(token.substr(pos + 1)); + option_values.push_back(buffer.back().c_str()); } + Ort::Status status(api.UpdateCUDAProviderOptions(cuda_options, + option_keys.data(), option_values.data(), option_keys.size())); + if (!status.IsOK()) { + OrtAllocator* allocator; + char* options; + Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); + Ort::ThrowOnError(api.GetCUDAProviderOptionsAsString(cuda_options, allocator, &options)); + ORT_THROW("[ERROR] [CUDA] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), + "\nSupported options are:\n", options); + } session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); #else ORT_THROW("CUDA is not supported in this build\n"); #endif } else if (provider_name == onnxruntime::kTensorrtExecutionProvider) { #ifdef USE_TENSORRT - int device_id = 0; - int trt_max_partition_iterations = 1000; - int trt_min_subgraph_size = 1; - size_t trt_max_workspace_size = 1 << 30; - bool trt_fp16_enable = false; - bool trt_int8_enable = false; - std::string trt_int8_calibration_table_name = ""; - bool trt_int8_use_native_calibration_table = false; - bool trt_dla_enable = false; - int trt_dla_core = 0; - bool trt_dump_subgraphs = false; - bool trt_engine_cache_enable = false; - std::string trt_engine_cache_path = ""; - bool trt_engine_decryption_enable = false; - std::string trt_engine_decryption_lib_path = ""; - bool trt_force_sequential_engine_build = false; - bool trt_context_memory_sharing_enable = false; - bool trt_layer_norm_fp32_fallback = false; - bool trt_timing_cache_enable = false; - bool trt_force_timing_cache = false; - bool trt_detailed_build_log = false; - bool trt_build_heuristics_enable = false; - bool trt_sparsity_enable = false; - int trt_builder_optimization_level = 3; - int trt_auxiliary_streams = -1; - std::string trt_tactic_sources = ""; - std::string trt_extra_plugin_lib_paths = ""; - std::string trt_profile_min_shapes = ""; - std::string trt_profile_max_shapes = ""; - std::string trt_profile_opt_shapes = ""; - bool trt_cuda_graph_enable = false; + const auto& api = Ort::GetApi(); + OrtTensorRTProviderOptionsV2* tensorrt_options; + Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); + std::vector option_keys, option_values; + // used to keep all option keys and value strings alive + std::list buffer; #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -207,272 +183,31 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } auto pos = token.find("|"); if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("[ERROR] [TensorRT] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + ORT_THROW( + "[ERROR] [TensorRT] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); } - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - if (key == "device_id") { - if (!value.empty()) { - device_id = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'device_id' should be a number.\n"); - } - } else if (key == "trt_max_partition_iterations") { - if (!value.empty()) { - trt_max_partition_iterations = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_partition_iterations' should be a number.\n"); - } - } else if (key == "trt_min_subgraph_size") { - if (!value.empty()) { - trt_min_subgraph_size = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_min_subgraph_size' should be a number.\n"); - } - } else if (key == "trt_max_workspace_size") { - if (!value.empty()) { - trt_max_workspace_size = std::stoull(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_workspace_size' should be a number.\n"); - } - } else if (key == "trt_fp16_enable") { - if (value == "true" || value == "True") { - trt_fp16_enable = true; - } else if (value == "false" || value == "False") { - trt_fp16_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_fp16_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_int8_enable") { - if (value == "true" || value == "True") { - trt_int8_enable = true; - } else if (value == "false" || value == "False") { - trt_int8_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_int8_calibration_table_name") { - if (!value.empty()) { - trt_int8_calibration_table_name = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_calibration_table_name' should be a non-empty string.\n"); - } - } else if (key == "trt_int8_use_native_calibration_table") { - if (value == "true" || value == "True") { - trt_int8_use_native_calibration_table = true; - } else if (value == "false" || value == "False") { - trt_int8_use_native_calibration_table = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_use_native_calibration_table' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_dla_enable") { - if (value == "true" || value == "True") { - trt_dla_enable = true; - } else if (value == "false" || value == "False") { - trt_dla_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_dla_core") { - if (!value.empty()) { - trt_dla_core = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_core' should be a number.\n"); - } - } else if (key == "trt_dump_subgraphs") { - if (value == "true" || value == "True") { - trt_dump_subgraphs = true; - } else if (value == "false" || value == "False") { - trt_dump_subgraphs = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_subgraphs' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_engine_cache_enable") { - if (value == "true" || value == "True") { - trt_engine_cache_enable = true; - } else if (value == "false" || value == "False") { - trt_engine_cache_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_engine_cache_path") { - if (!value.empty()) { - trt_engine_cache_path = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_path' should be a non-empty string.\n"); - } - } else if (key == "trt_engine_decryption_enable") { - if (value == "true" || value == "True") { - trt_engine_decryption_enable = true; - } else if (value == "false" || value == "False") { - trt_engine_decryption_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_engine_decryption_lib_path") { - if (!value.empty()) { - trt_engine_decryption_lib_path = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_lib_path' should be a non-empty string.\n"); - } - } else if (key == "trt_force_sequential_engine_build") { - if (value == "true" || value == "True") { - trt_force_sequential_engine_build = true; - } else if (value == "false" || value == "False") { - trt_force_sequential_engine_build = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_sequential_engine_build' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_context_memory_sharing_enable") { - if (value == "true" || value == "True") { - trt_context_memory_sharing_enable = true; - } else if (value == "false" || value == "False") { - trt_context_memory_sharing_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_context_memory_sharing_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_layer_norm_fp32_fallback") { - if (value == "true" || value == "True") { - trt_layer_norm_fp32_fallback = true; - } else if (value == "false" || value == "False") { - trt_layer_norm_fp32_fallback = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_layer_norm_fp32_fallback' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_timing_cache_enable") { - if (value == "true" || value == "True") { - trt_timing_cache_enable = true; - } else if (value == "false" || value == "False") { - trt_timing_cache_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_force_timing_cache") { - if (value == "true" || value == "True") { - trt_force_timing_cache = true; - } else if (value == "false" || value == "False") { - trt_force_timing_cache = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_timing_cache' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_detailed_build_log") { - if (value == "true" || value == "True") { - trt_detailed_build_log = true; - } else if (value == "false" || value == "False") { - trt_detailed_build_log = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_detailed_build_log' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_build_heuristics_enable") { - if (value == "true" || value == "True") { - trt_build_heuristics_enable = true; - } else if (value == "false" || value == "False") { - trt_build_heuristics_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_build_heuristics_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_sparsity_enable") { - if (value == "true" || value == "True") { - trt_sparsity_enable = true; - } else if (value == "false" || value == "False") { - trt_sparsity_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_sparsity_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_builder_optimization_level") { - if (!value.empty()) { - trt_builder_optimization_level = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_builder_optimization_level' should be a number and default to 2.\n"); - } - } else if (key == "trt_auxiliary_streams") { - if (!value.empty()) { - trt_auxiliary_streams = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_auxiliary_streams' should be a number.\n"); - } - } else if (key == "trt_tactic_sources") { - if (!value.empty()) { - trt_tactic_sources = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a non-empty string.\n"); - } - } else if (key == "trt_extra_plugin_lib_paths") { - if (!value.empty()) { - trt_extra_plugin_lib_paths = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a non-empty string.\n"); - } - } else if (key == "trt_profile_min_shapes") { - if (!value.empty()) { - trt_profile_min_shapes = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_min_shapes' should be a non-empty string.\n"); - } - } else if (key == "trt_profile_max_shapes") { - if (!value.empty()) { - trt_profile_max_shapes = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_max_shapes' should be a non-empty string.\n"); - } - } else if (key == "trt_profile_opt_shapes") { - if (!value.empty()) { - trt_profile_opt_shapes = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a non-empty string.\n"); - } - } else if (key == "trt_cuda_graph_enable") { - if (value == "true" || value == "True") { - trt_cuda_graph_enable = true; - } else if (value == "false" || value == "False") { - trt_cuda_graph_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else { - ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_timing_cache_enable', 'trt_force_timing_cache', 'trt_detailed_build_log', 'trt_build_heuristics_enable', 'trt_sparsity_enable', 'trt_builder_optimization_level', 'trt_auxiliary_streams', 'trt_tactic_sources', 'trt_extra_plugin_lib_paths', 'trt_profile_min_shapes', 'trt_profile_max_shapes', 'trt_profile_opt_shapes', 'trt_cuda_graph_enable'] \n"); - } + buffer.emplace_back(token.substr(0, pos)); + option_keys.push_back(buffer.back().c_str()); + buffer.emplace_back(token.substr(pos + 1)); + option_values.push_back(buffer.back().c_str()); + } + + Ort::Status status(api.UpdateTensorRTProviderOptions(tensorrt_options, + option_keys.data(), option_values.data(), option_keys.size())); + if (!status.IsOK()) { + OrtAllocator* allocator; + char* options; + Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); + Ort::ThrowOnError(api.GetTensorRTProviderOptionsAsString(tensorrt_options, allocator, &options)); + ORT_THROW("[ERROR] [TensorRT] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), + "\nSupported options are:\n", options); } - OrtTensorRTProviderOptionsV2 tensorrt_options; - tensorrt_options.device_id = device_id; - tensorrt_options.has_user_compute_stream = 0; - tensorrt_options.user_compute_stream = nullptr; - tensorrt_options.trt_max_partition_iterations = trt_max_partition_iterations; - tensorrt_options.trt_min_subgraph_size = trt_min_subgraph_size; - tensorrt_options.trt_max_workspace_size = trt_max_workspace_size; - tensorrt_options.trt_fp16_enable = trt_fp16_enable; - tensorrt_options.trt_int8_enable = trt_int8_enable; - tensorrt_options.trt_int8_calibration_table_name = trt_int8_calibration_table_name.c_str(); - tensorrt_options.trt_int8_use_native_calibration_table = trt_int8_use_native_calibration_table; - tensorrt_options.trt_dla_enable = trt_dla_enable; - tensorrt_options.trt_dla_core = trt_dla_core; - tensorrt_options.trt_dump_subgraphs = trt_dump_subgraphs; - tensorrt_options.trt_engine_cache_enable = trt_engine_cache_enable; - tensorrt_options.trt_engine_cache_path = trt_engine_cache_path.c_str(); - tensorrt_options.trt_engine_decryption_enable = trt_engine_decryption_enable; - tensorrt_options.trt_engine_decryption_lib_path = trt_engine_decryption_lib_path.c_str(); - tensorrt_options.trt_force_sequential_engine_build = trt_force_sequential_engine_build; - tensorrt_options.trt_context_memory_sharing_enable = trt_context_memory_sharing_enable; - tensorrt_options.trt_layer_norm_fp32_fallback = trt_layer_norm_fp32_fallback; - tensorrt_options.trt_timing_cache_enable = trt_timing_cache_enable; - tensorrt_options.trt_force_timing_cache = trt_force_timing_cache; - tensorrt_options.trt_detailed_build_log = trt_detailed_build_log; - tensorrt_options.trt_build_heuristics_enable = trt_build_heuristics_enable; - tensorrt_options.trt_sparsity_enable = trt_sparsity_enable; - tensorrt_options.trt_builder_optimization_level = trt_builder_optimization_level; - tensorrt_options.trt_auxiliary_streams = trt_auxiliary_streams; - tensorrt_options.trt_tactic_sources = trt_tactic_sources.c_str(); - tensorrt_options.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str(); - tensorrt_options.trt_profile_min_shapes = trt_profile_min_shapes.c_str(); - tensorrt_options.trt_profile_max_shapes = trt_profile_max_shapes.c_str(); - tensorrt_options.trt_profile_opt_shapes = trt_profile_opt_shapes.c_str(); - tensorrt_options.trt_cuda_graph_enable = trt_cuda_graph_enable; - - session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options); + + session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); OrtCUDAProviderOptions cuda_options; - cuda_options.device_id = device_id; + cuda_options.device_id = tensorrt_options->device_id; cuda_options.cudnn_conv_algo_search = static_cast(performance_test_config.run_config.cudnn_conv_algo); cuda_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream; // TODO: Support arena configuration for users of perf test @@ -505,8 +240,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (key == "device_type") { std::set ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16", - "VPUX_FP16", "VPUX_U8"}; + "GPU.0_FP16", "GPU.1_FP16"}; if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { ov_options[key] = value; } else if (value.find("HETERO:") == 0) { @@ -519,17 +253,17 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW( "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " - "'GPU.0_FP16', 'GPU.1_FP16', 'VPUX_FP16', 'VPUX_U8' or from" + "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); } } else if (key == "device_id") { ov_options[key] = value; - } else if (key == "enable_vpu_fast_compile") { + } else if (key == "enable_npu_fast_compile") { if (value == "true" || value == "True" || value == "false" || value == "False") { ov_options[key] = value; } else { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_vpu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); + ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_npu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); } } else if (key == "enable_opencl_throttling") { if (value == "true" || value == "True" || @@ -564,7 +298,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ov_options[key] = value; } } else { - ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_vpu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling|true'] \n"); + ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling|true'] \n"); } } session_options.AppendExecutionProvider("OpenVINO", ov_options); diff --git a/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc index e37206d6aebf2..b7cead66bd7fb 100644 --- a/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc @@ -143,7 +143,7 @@ void L1NormalizationWithZeroNorm() { vector expected_output = {0.5f, 0.5f, 0.f, 0.f}; test.AddOutput("Y", input_dims, expected_output); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(LpNormalizationTest, L1NormalizationWithZeroNorm) { @@ -163,7 +163,7 @@ void L2NormalizationWithZeroNorm() { vector expected_output = {1.f, 0.f, 0.f, 0.f}; test.AddOutput("Y", input_dims, expected_output); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(LpNormalizationTest, L2NormalizationWithZeroNorm) { diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index d1a523b1eecf9..b9875b9553a55 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -762,7 +762,7 @@ TEST(RNNTest, RNN_invalid_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // the CUDA RNN version allows the invalid sequence lengths, so disable testing on CUDA and TensorRT - test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); }; // should batch batch_size to be valid @@ -860,7 +860,7 @@ TEST(RNNTest, RNN_bidirectional_with_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(RNNTest, RNN_with_invalid_activation_load_failure) { diff --git a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc index c95ac1603a317..c3d91100605e9 100644 --- a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc @@ -66,7 +66,7 @@ TEST(CompressTest, Compress_3dims_has_extra_condition) { // has condition length = 3 > input_dim[axis] = 2 test.AddInput("condition", {3}, {0, 1, 1}); test.AddOutput("output", {2, 1, 3}, {4.0f, 5.0f, 6.0f, 10.0f, 11.0f, 12.0f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(CompressTest, Compress_3dims_has_extra_input) { diff --git a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc index 2120da604f94a..d2aa5dd428fec 100644 --- a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc @@ -99,7 +99,7 @@ TEST(TensorOpTest, Unsqueeze_scalar_2) { test.AddInput("input", {}, std::vector{1.0f}); test.AddInput("axes", {2}, std::vector{0, -1}, axes_is_initializer); test.AddOutput("output", {1, 1}, std::vector{1.0f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); }; run_test(false); run_test(true); diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index d45323190c514..06da2a5304716 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -42,12 +42,16 @@ struct ConvTransposeOp { test->AddAttribute("pads", padding); if (!output_padding.empty()) { test->AddAttribute("output_padding", output_padding); + } else { + output_padding = {0, 0, 0, 0}; } std::vector output_dims = { input_dims[0], channels, - (kernel_shape[1] - 1) * dilations[1] + (input_dims[2] - 1) * strides[1] - (padding[1] + padding[0]) + 1, - (kernel_shape[0] - 1) * dilations[0] + (input_dims[3] - 1) * strides[0] - (padding[3] + padding[2]) + 1}; + (kernel_shape[1] - 1) * dilations[1] + (input_dims[2] - 1) * strides[1] - (padding[1] + padding[0]) + 1 + + output_padding[2], + (kernel_shape[0] - 1) * dilations[0] + (input_dims[3] - 1) * strides[0] - (padding[3] + padding[2]) + 1 + + output_padding[3]}; std::vector output_data = FillZeros(output_dims); test->AddOutput("Y", output_dims, output_data); diff --git a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc index 3d1f81e6bc282..e0d59901da80c 100644 --- a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc @@ -31,8 +31,8 @@ struct PoolOp { std::vector output_dims = { input_dims[0], channels, - (kernel_shape[1] - 1) + (input_dims[2] - 1) * strides[1] - (padding[1] + padding[0]) + 1, - (kernel_shape[0] - 1) + (input_dims[3] - 1) * strides[0] - (padding[3] + padding[2]) + 1}; + (input_dims[2] - (kernel_shape[0] - 1) + padding[1] + padding[0] - 1) / strides[0] + 1, + (input_dims[3] - (kernel_shape[1] - 1) + padding[3] + padding[2] - 1) / strides[1] + 1}; std::vector output_data = FillZeros(output_dims); test->AddOutput("Y", output_dims, output_data); diff --git a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc index 9b65ca7bda3e2..b4e8f5390787c 100644 --- a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc @@ -175,13 +175,7 @@ static void RunBatchNormQDQTest(const TestInputDef& input_def, // TODO: FIX TRANSLATION!!! // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 3. -// QNN v2.13 -// Inaccuracy detected for output 'output', element 4. -// Output quant params: scale=0.019084848463535309, zero_point=9. -// Expected val: 1.7755576372146606 -// QNN QDQ val: 2.9963212013244629 (err 1.2207635641098022) -// CPU QDQ val: 0.82064849138259888 (err 0.95490914583206177) -TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm1D) { +TEST_F(QnnHTPBackendTests, BatchNorm1D) { constexpr int64_t num_channels = 2; RunBatchNormQDQTest(TestInputDef({1, num_channels, 3}, false, {-5.0f, -4.0f, -3.0f, 0.0f, 2.0f, 5.0f}), // Input data @@ -193,13 +187,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm1D) { // TODO: FIX TRANSLATION!!! // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 4. -// QNN v2.13 -// Inaccuracy detected for output 'output', element 14. -// Output quant params: scale=0.023071292787790298, zero_point=19. -// Expected val: 2.8554618358612061 -// QNN QDQ val: 5.3294687271118164 (err 2.4740068912506104) -// CPU QDQ val: 1.6611330509185791 (err 1.194328784942627) -TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm2D) { +TEST_F(QnnHTPBackendTests, BatchNorm2D) { constexpr int64_t num_channels = 2; std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; @@ -226,4 +214,4 @@ TEST_F(QnnHTPBackendTests, BatchNorm3D) { } // namespace test } // namespace onnxruntime -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index aa96e1533653e..d9f917f6d187c 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -590,6 +590,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { // uint64_t compilation_without_cache_ms, compilation_with_cache_ms; // First session is created with TRT EP with timing cache enabled + // Not specifying a trt_timing_cache_path will result in using the working directory params.trt_timing_cache_enable = 1; { // auto start = chrono::steady_clock::now(); diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index ecf4b001eec68..c48b07422d452 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -140,6 +140,9 @@ def create_backend_test(test_name=None): if backend.supports_device("OPENVINO_CPU_FP16"): current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP16") + if backend.supports_device("OPENVINO_NPU_FP16"): + current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_NPU_FP16") + if backend.supports_device("OPENVINO"): current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_opset18") diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index a9b55122c6806..e0fb3979a9f55 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import unittest +from typing import Tuple import numpy as np import onnxscript @@ -18,6 +19,857 @@ def shard_tensor(X, rank, axis, num_shards): return np.split(X, num_shards, axis)[rank] +def shard_tensor_per_device_mesh(X, rank, axis, device_mesh): + if axis is None: + return X + shards = np.split(X, len(device_mesh), axis) + selected_shards = tuple(shard for device_id, shard in zip(device_mesh, shards) if device_id == rank) + return np.concatenate(selected_shards, axis=axis) + + +def translate_device_mesh_to_attrs(device_mesh: np.ndarray): + device_mesh_shape = "[" + ",".join(str(dim) for dim in device_mesh.shape) + "]" + device_mesh_elements = "[" + ",".join(str(elem) for elem in device_mesh.flat) + "]" + return device_mesh_shape, device_mesh_elements + + +def parse_sharding_spec(spec: str): + axis_conditions = [] + sharding_device_axes = [] + token_index = 0 + while True: + token = spec[token_index] + if token == "R": + axis_conditions.append("R") + sharding_device_axes.append(None) + token_index += 1 + elif token == "S": + axis_conditions.append("S") + # Move token pointer to "["" + token_index += 1 + assert spec[token_index] == "[" + number_tokens = "" + while True: + token_index += 1 + token = spec[token_index] + if token == "]": + break + number_tokens += token + assert spec[token_index] == "]" + # Skip "]" and point to next S/R token + token_index += 1 + sharding_device_axes.append(int(number_tokens)) + else: + raise ValueError(f"Invalid spec: {spec}") + if token_index >= len(spec): + break + return axis_conditions, sharding_device_axes + + +def find_shard_axis(axis_conditions, shard_device_axes): + sharded_axis = None + sharded_axis_count = 0 + for i, cond in enumerate(axis_conditions): + if cond == "S": + sharded_axis = i + sharded_axis_count += 1 + assert sharded_axis_count in (0, 1), "Can shard at most one axis per tensor." + if sharded_axis is not None: + assert shard_device_axes[sharded_axis] == 0, "Device mesh must be 1-D, so 0 is the only valid device mesh axis." + return sharded_axis + + +def shard_tensor_per_spec(tensor: np.ndarray, rank: int, spec: str, device_mesh: np.ndarray): + axis_conditions, shard_device_axes = parse_sharding_spec(spec) + sharded_axis = find_shard_axis(axis_conditions, shard_device_axes) + return shard_tensor_per_device_mesh(tensor, rank, sharded_axis, list(device_mesh.flat)) + + +class TestDistributedReshape(unittest.TestCase): + def _check_distributed_reshape( + self, + shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + input_device_meshs: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshs: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) + assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) + assert len(input_device_meshs) == len(input_shard_specs) + assert len(output_device_meshs) == len(output_shard_specs) + + input_device_mesh_shapes = [] + input_device_mesh_elements = [] + for device_mesh in input_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + input_device_mesh_shapes.append(device_mesh_shape) + input_device_mesh_elements.append(device_mesh_element) + + output_device_mesh_shapes = [] + output_device_mesh_elements = [] + for device_mesh in output_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + output_device_mesh_shapes.append(device_mesh_shape) + output_device_mesh_elements.append(device_mesh_element) + + @onnxscript.script() + def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): + return MICROSOFT_OPSET.DistributedReshape( + data_tensor, + shape_tensor, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + data_tensor = np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) + shape_tensor = np.array( + target_shape, + dtype=np.int64, + ) + + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + assert "S" not in input_shard_specs[1], "Shape should not be sharded." + + expected = np.reshape(data_tensor, shape_tensor) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + + onnx_model = distributed_reshape_instance.to_model_proto( + input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data_tensor, + "shape_tensor": shape_tensor, + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_reshape_two_axis_fusion_shape_2_3_sr_01_shape_6_s_01(self): + # Two axis fusion. + # S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=( + 2, + 3, + ), + target_shape=(6,), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]R", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]",), + ) + + def test_reshape_two_axis_fusion_shape_2_4_rs_01_shape_8_s_0101(self): + # Two axis fusion. + # RS[0], shape=[2, 4], device_mesh=[0, 1] -> S[0], shape = [8], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=( + 2, + 4, + ), + target_shape=(8,), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]",), + ) + + def test_reshape_two_axis_fusion_shape_2_3_5_srr_01_shape_2_15_sr_01(self): + # Two axis fusion. + # S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=( + 2, + 3, + 5, + ), + target_shape=( + 2, + 15, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_fusion_shape_2_3_5_rsr_01_shape_2_15_sr_01(self): + # Two axis fusion. + # RS[0]R, shape=[2, 4, 5], device_mesh=[0, 1] -> RS[0], shape = [2, 20], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=( + 2, + 4, + 5, + ), + target_shape=( + 2, + 20, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]R", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_fusion_shape_2_3_6_rrs_01_shape_2_18_rs_010101(self): + # Two axis fusion. + # RRS[0], shape=[2, 3, 6], device_mesh=[0, 1] -> RS[0], shape = [2, 18], device_mesh=[0, 1, 0, 1, 0, 1] + self._check_distributed_reshape( + shape=( + 2, + 3, + 6, + ), + target_shape=( + 2, + 18, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_shard_specs=("RS[0]",), + ) + # Two axis fusion. + # RRS[0], shape=[2, 3, 8], device_mesh=[0, 1, 0, 1] -> RS[0], shape = [2, 24], device_mesh=[0, 1, 0, 1] * 3 + + # Two axis fusion. + # RS[0]R, shape=[2, 8, 3], device_mesh=[0, 1, 0, 1] -> RS[0], shape = [2, 24], device_mesh=[0, 1, 0, 1] + + def test_reshape_two_axis_decomposition_shape_6_s_01_shape_2_3_sr_01(self): + # Two axis decomposition + # S[0], shape=[6], device_mesh=[0, 1] -> S[0]R, shape=[2, 3], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(6,), + target_shape=( + 2, + 3, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_1_16_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> RS[0], shape=[1, 16], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 1, + 16, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_2_8_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[2, 8], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 2, + 8, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_4_4_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[4, 4], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 4, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_8_2_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[8, 2], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 8, + 2, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_16_1_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[16, 1], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 16, + 1, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_1_16_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> RS[0], shape=[1, 16], device_mesh=[0, 1, 0, 1] + + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 1, + 16, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_2_8_rs_01(self): + # Two axis decomposition + # repeats=2 8 = repeats * [unique IDs] + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> RS[0], shape=[2, 8], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 2, + 8, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_4_4_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[4, 4], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 4, + 4, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_8_2_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[8, 2], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 8, + 2, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_16_1_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[16, 1], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 16, + 1, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_21_4096_s_01_shape_3_7_4096_rrs_01(self): + # Two axis decomposition + # [21, 4096] -> [3, 7, 4096] + # data: (21, 2048), (RS, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRS, [0, 1]) + self._check_distributed_reshape( + shape=( + 21, + 4096, + ), + target_shape=( + 3, + 7, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs_01_shape_3_7_64_64_rrsr_01(self): + # Two axis decomposition + # [3, 7, 4096] -> [3, 7, 64, 64] + # data: (3, 7, 2048), (RRS, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRSR, [0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 7, + 4096, + ), + target_shape=( + 3, + 7, + 64, + 64, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRS[0]R",), + ) + + def test_reshape_two_axis_fusion_shape_3_7_4096_rrr_01_shape_21_4906_rr_01(self): + # Two axis fusion + # [3, 7, 4096] -> [21, 4096] + # data: (3, 7, 4096), (RRR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RR, [0, 1]) + self._check_distributed_reshape( + shape=( + 3, + 7, + 4096, + ), + target_shape=( + 21, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_reshape_two_axis_fusion_shape_21_4096_rrr_01_shape_3_7_4906_rr_01(self): + # Two axis fusion + # [21, 4096] -> [3, 7, 4096] + # data: (21, 4096), (RR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRR, [0, 1]) + self._check_distributed_reshape( + shape=( + 21, + 4096, + ), + target_shape=( + 3, + 7, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRR",), + ) + + def test_reshape_two_axis_fusion_shape_3_64_7_64_rsrr_01_shape_192_7_64_srr_010101(self): + # Two axis fusion + # [3, 64, 7, 64] -> [192, 7, 64] + # data: (3, 32, 7, 64), (RSRR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (SRR, [0, 1, 0, 1, 0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 64, + 7, + 64, + ), + target_shape=( + 192, + 7, + 64, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]RR", "R"), + output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_shard_specs=("S[0]RR",), + ) + + def test_reshape_two_axis_decomposition_shape_192_7_7_srr_010101_shape_3_64_7_7_rsrr_01(self): + # Two axis decomposition + # [192, 7, 7] -> [3, 64, 7, 7] + # data: (96, 7, 7), (SRR, [0, 1, 0, 1, 0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RSRR, [0.0, 1.0]) + + self._check_distributed_reshape( + shape=( + 192, + 7, + 7, + ), + target_shape=( + 3, + 64, + 7, + 7, + ), + input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]RR",), + ) + + def test_reshape_two_axis_fusion_shape_3_64_7_7_rsrr_01_shape_192_7_7_srr_010101(self): + # Two axis fusion + # [3, 64, 7, 7] -> [192, 7, 7] + # data: (3, 32, 7, 7), (RSRR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (SRR, [0, 1, 0, 1, 0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 64, + 7, + 7, + ), + target_shape=( + 192, + 7, + 7, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]RR", "R"), + output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_shard_specs=("S[0]RR",), + ) + + def test_reshape_two_axis_decomposition_shape_192_7_64_srr_010101_shape_3_64_7_64_rsrr_01(self): + # Two axis decomposition + # [192, 7, 64] -> [3, 64, 7, 64] + # data: (96, 7, 64), (SRR, [0, 1, 0, 1, 0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RSRR, [0.0, 1.0]) + + self._check_distributed_reshape( + shape=( + 192, + 7, + 64, + ), + target_shape=( + 3, + 64, + 7, + 64, + ), + input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]RR",), + ) + + def test_reshape_two_axis_fusion_shape_3_7_64_64_rrsr_01_shape_3_7_4096_rrs_01(self): + # Two axis fusion + # [3, 7, 64, 64] -> [3, 7, 4096] + # data: (3, 7, 32, 64), (RRSR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRS, [0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 7, + 64, + 64, + ), + target_shape=( + 3, + 7, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]R", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRS[0]",), + ) + + def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self): + # Two axis fusion + # [3, 7, 4096] -> [21, 4096] + # data: (3, 7, 2048), (RRS, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RS, [0, 1]) + self._check_distributed_reshape( + shape=( + 3, + 7, + 4096, + ), + target_shape=( + 21, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + +class TestDistributedExpand(unittest.TestCase): + def _check_distributed_expand( + self, + shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + input_device_meshs: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshs: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) + assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) + assert len(input_device_meshs) == len(input_shard_specs) + assert len(output_device_meshs) == len(output_shard_specs) + + input_device_mesh_shapes = [] + input_device_mesh_elements = [] + for device_mesh in input_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + input_device_mesh_shapes.append(device_mesh_shape) + input_device_mesh_elements.append(device_mesh_element) + + output_device_mesh_shapes = [] + output_device_mesh_elements = [] + for device_mesh in output_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + output_device_mesh_shapes.append(device_mesh_shape) + output_device_mesh_elements.append(device_mesh_element) + + @onnxscript.script() + def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): + return MICROSOFT_OPSET.DistributedExpand( + data_tensor, + shape_tensor, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + data_tensor = np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) + shape_tensor = np.array( + target_shape, + dtype=np.int64, + ) + + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + assert "S" not in input_shard_specs[1], "Shape should not be sharded." + + expected = data_tensor * np.ones(shape_tensor) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + + onnx_model = distributed_expand_instance.to_model_proto( + input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data_tensor, + "shape_tensor": shape_tensor, + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_expand_sharded_on_expanded_axis(self): + # data: shape=[8,1], spec=(RR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(RS, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 8, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_expand_sharded_on_expanded_axis_with_device_mesh_0101(self): + # data: shape=[8,1], spec=(RR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(RS, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 8, + 8, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_expand_replicated_on_expanded_axis(self): + # data: shape=[8,1], spec=(RR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(RR, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 1, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_expand_with_pass_through_sharding_spec(self): + # data: shape=[8,1], spec=(SR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(SR, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 1, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=( + "S[0]R", + "R", + ), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_expand_in_tiny_llama(self): + # data: shape=[2,4,256,4], spec=(RSRR, [0,1]) + # shape: shape=[4], spec=(R, [0,1,2,3]), value=[2,4,256,4] + # output: shape=[2,4,256,4], spec=(RSRR, [0,1]) + self._check_distributed_expand( + shape=( + 2, + 4, + 256, + 4, + ), + target_shape=( + 2, + 4, + 256, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]RR",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): # It means 1-D tensor with single element: [2]. diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py new file mode 100644 index 0000000000000..784ae8ce70bd8 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py @@ -0,0 +1,284 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# pylint: disable=C0116,W0212,R1720,C0103,C0114 +# +# Note: the precision is different on V100, H100 even with the same code. +# The thresholds were adjusted on H100 as the precision seems lower on this machine. + +import itertools +import unittest +import warnings + +import numpy as np +import parameterized +from numpy.testing import assert_allclose +from onnx import TensorProto +from onnx.checker import check_model +from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info +from onnx.numpy_helper import from_array + +from onnxruntime import InferenceSession + + +class TestFloat8Gemm8(unittest.TestCase): + def get_model_gemm( + self, + float_name, + alpha=1.0, + beta=0.0, + transA=0, + transB=0, + domain="", + dtype=TensorProto.FLOAT, + activation="NONE", + ): + proto_type = getattr(TensorProto, float_name) + use_f8 = proto_type in (TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E5M2) + + a = make_tensor_value_info("A", TensorProto.FLOAT, [None, None]) + b = make_tensor_value_info("B", TensorProto.FLOAT, [None, None]) + d = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None]) + + inits = [] + kwargs = {} + node_inputs = ["Af", "Bf"] + inputs = [a, b] + bias = beta != 0 + if bias: + inputs.append(make_tensor_value_info("C", TensorProto.FLOAT, [None, None])) + node_inputs = ["Af", "Bf", "Cf"] + if use_f8: + node_inputs.extends(["one"] * 3) + elif use_f8: + node_inputs.append("") + node_inputs.extend(["one"] * 3) + + if use_f8: + assert domain == "com.microsoft" + inits.append(from_array(np.array([1], dtype=np.float32), name="one")) + kwargs = dict( + domain=domain, + dtype=dtype, + ) + if activation is not None: + kwargs["activation"] = activation + op_name = "GemmFloat8" + elif domain == "com.microsoft": + op_name = "GemmFloat8" + kwargs = dict( + domain=domain, + dtype=dtype, + ) + else: + op_name = "Gemm" + nodes = [ + make_node("Cast", ["A"], ["Af"], to=proto_type), + make_node("Cast", ["B"], ["Bf"], to=proto_type), + make_node("Cast", ["C"], ["Cf"], to=proto_type) if bias else None, + make_node( + op_name, + node_inputs, + ["Yf"], + transA=transA, + transB=transB, + alpha=alpha, + beta=beta, + **kwargs, + ), + make_node("Cast", ["Yf"], ["Y"], to=TensorProto.FLOAT), + ] + nodes = [n for n in nodes if n is not None] + graph = make_graph(nodes, "gemm", inputs, [d], inits) + onnx_model = make_model(graph, opset_imports=[make_opsetid("", 19)], ir_version=9) + if domain != "com.microsoft": + check_model(onnx_model) + return onnx_model + + def common_test_model_gemm(self, float_type, mul=0.33, atol=0, rtol=0, square=True, **kwargs): + if square: + a = (np.arange(256) * 0.01).astype(np.float32).reshape((-1, 16)) + b = (np.arange(256) * -0.01).astype(np.float32).reshape((-1, 16)) + c = (np.arange(256) * 0.03).astype(np.float32).reshape((-1, 16)) + b[:, 0] += 1 + else: + a = (np.arange(256) / 256).astype(np.float32).reshape((32, -1)) + b = (np.arange(512) / 512).astype(np.float32).reshape((32, -1)) + c = (np.arange(128) / 128).astype(np.float32).reshape((8, 16)) + + feeds = {"A": a, "B": b} + + expected = (a.T if kwargs.get("transA", 0) else a) @ (b.T if kwargs.get("transB", 0) else b) + expected *= kwargs.get("alpha", 1.0) + if kwargs.get("beta", 0) != 0: + expected += kwargs["beta"] * c + feeds["C"] = c + + onnx_model = self.get_model_gemm("FLOAT", **kwargs) + + ref = InferenceSession( + onnx_model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + y = ref.run(None, feeds)[0] + if float_type in ("FLOAT", "FLOAT16"): + try: + assert_allclose(expected, y, atol=atol, rtol=rtol) + except Exception as e: + + def check(f): + try: + return f()[:2, :2] + except Exception as e: + return str(e) + + raise AssertionError( + f"Gemm ERROR len(inputs)={len(feeds)}" + f"\na@b=\n{check(lambda:a@b)}" + f"\na.T@b=\n{check(lambda:a.T@b)}" + f"\na@b.T=\n{check(lambda:a@b.T)}" + f"\na.T@b.T=\n{check(lambda:a.T@b.T)}" + f"\n----\nb@a=\n{check(lambda:b@a)}" + f"\nb.T@a=\n{check(lambda:b.T@a)}" + f"\nb@a.T=\n{check(lambda:b@a.T)}" + f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}" + f"\n----\nexpected=\n{expected[:2,:2]}" + f"\n----\ngot=\n{y[:2,:2]}" + f"\nkwargs={kwargs}" + ) from e + + self.assertEqual(expected.shape, y.shape) + self.assertEqual(expected.dtype, y.dtype) + + onnx_model_f8 = self.get_model_gemm(float_type, domain="com.microsoft", **kwargs) + try: + ref8 = InferenceSession( + onnx_model_f8.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + except Exception as e: + if "CUDA < 12.0 does not support bias" in str(e): + return + raise AssertionError(f"Could not load model {onnx_model_f8}") from e + try: + y = ref8.run(None, feeds)[0] + except Exception as e: + if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e): + # Skipping. This machine does not support float8. + warnings.warn("unable to test with float8 on this machine.") + return + raise AssertionError(f"Could not execute model {onnx_model_f8}") from e + try: + assert_allclose(expected, y, atol=atol, rtol=rtol) + except Exception as e: + + def check(f): + try: + return f()[:2, :2] + except Exception as e: + return str(e) + + raise AssertionError( + f"Gemm ERROR len(inputs)={len(feeds)}" + f"\na@b=\n{check(lambda:a@b)}" + f"\na.T@b=\n{check(lambda:a.T@b)}" + f"\na@b.T=\n{check(lambda:a@b.T)}" + f"\na.T@b.T=\n{check(lambda:a.T@b.T)}" + f"\n----\nb@a=\n{check(lambda:b@a)}" + f"\nb.T@a=\n{check(lambda:b.T@a)}" + f"\nb@a.T=\n{check(lambda:b@a.T)}" + f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}" + f"\n----\nexpected=\n{expected[:2,:2]}" + f"\n----\ngot=\n{y[:2,:2]}" + f"\nkwargs={kwargs}" + ) from e + self.assertEqual(expected.shape, y.shape) + self.assertEqual(expected.dtype, y.dtype) + + def test_model_gemm_float(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3) + + def test_model_gemm_float_default_values(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation=None) + + def test_model_gemm_float_relu(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="RELU") + + def test_model_gemm_float_gelu(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="GELU") + + def test_model_gemm_float_bias(self): + self.common_test_model_gemm("FLOAT", transA=1, beta=1.0, rtol=1e-3) + + def test_model_gemm_float16(self): + self.common_test_model_gemm( + "FLOAT16", + rtol=1e-2, + dtype=TensorProto.FLOAT16, + transB=1, + ) + + def test_model_gemm_float8_e4m3(self): + self.common_test_model_gemm( + "FLOAT8E4M3FN", + rtol=0.5, + dtype=TensorProto.FLOAT, + transA=0, + transB=1, + alpha=10.0, + ) + + @parameterized.parameterized.expand(list(itertools.product([0, 1], [0, 1]))) + def test_combinations_square_matrices(self, transA, transB): + self.common_test_model_gemm("FLOAT", transA=transA, transB=transB, rtol=1e-3) + + @parameterized.parameterized.expand( + [ + ((2, 3), (3, 5), 0, 0), + ((2, 3), (5, 3), 0, 1), + ((2, 3), (5, 2), 1, 1), + ((2, 3), (2, 5), 1, 0), + ] + ) + def test_combinations(self, shapeA, shapeB, transA, transB): + model = make_model( + make_graph( + [ + make_node( + "GemmFloat8", + ["A", "B"], + ["Y"], + transA=transA, + transB=transB, + domain="com.microsoft", + ) + ], + "f8", + [ + make_tensor_value_info("A", TensorProto.FLOAT, [None, None]), + make_tensor_value_info("B", TensorProto.FLOAT, [None, None]), + ], + [make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])], + ) + ) + + sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) + a = np.arange(np.prod(shapeA)).reshape(shapeA).astype(np.float32) + b = np.arange(np.prod(shapeB)).reshape(shapeB).astype(np.float32) + try: + expected = (a.T if transA else a) @ (b.T if transB else b) + except Exception as e: + raise AssertionError( + f"Unable to multiply shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}" + ) from e + try: + got = sess.run(None, {"A": a, "B": b}) + except Exception as e: + raise AssertionError( + f"Unable to run Gemm with shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}" + ) from e + self.assertEqual(expected.shape, got[0].shape) + self.assertEqual(expected.dtype, got[0].dtype) + assert_allclose(expected, got[0]) + + +if __name__ == "__main__": + # TestFloat8Gemm8().test_model_gemm_float() + unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/transformers/test_group_norm.py b/onnxruntime/test/python/transformers/test_group_norm.py new file mode 100644 index 0000000000000..bf295a65c8b53 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_group_norm.py @@ -0,0 +1,541 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import statistics +from dataclasses import dataclass +from enum import Enum +from time import perf_counter +from typing import Optional, Tuple + +import numpy +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession +from onnxruntime.transformers.io_binding_helper import CudaSession + +torch.manual_seed(0) + + +class GroupNormOpType(Enum): + GROUP_NORM = 1 + SKIP_GROUP_NORM = 2 + + +@dataclass +class GroupNormConfig: + batch_size: int + height: int + width: int + channels: int + epsilon: float = 1e-5 + num_groups: int = 32 + activation: bool = False + channels_last: bool = True + fp16: bool = False + + op_type: GroupNormOpType = GroupNormOpType.GROUP_NORM + has_bias: bool = False + has_add_out: bool = False + broadcast_skip: int = 0 # 2 for (N, C), 4 for (N, 1, 1, C) + + def get_skip_symbolic_shape(self): + skip_shape = {0: ["N", "H", "W", "C"], 2: ["N", "C"], 4: ["N", 1, 1, "C"]} + return skip_shape[self.broadcast_skip] + + def get_skip_shape(self): + skip_shape = { + 0: [self.batch_size, self.height, self.width, self.channels], + 2: [self.batch_size, self.channels], + 4: [self.batch_size, 1, 1, self.channels], + } + return skip_shape[self.broadcast_skip] + + def broadcast(self, skip: torch.Tensor): + if self.broadcast_skip == 2: + return skip.reshape(self.batch_size, 1, 1, self.channels) + + return skip + + @staticmethod + def create( + b: int, + h: int, + w: int, + c: int, + fp16: bool = False, + activation: bool = False, + template: int = 0, + num_groups: int = 32, + ): + if template == 0: + return GroupNormConfig( + b, h, w, c, fp16=fp16, activation=activation, op_type=GroupNormOpType.GROUP_NORM, num_groups=num_groups + ) + + if template == 1: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 2: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=2, + num_groups=num_groups, + ) + + if template == 3: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=False, + broadcast_skip=4, + num_groups=num_groups, + ) + + if template == 4: # No bias + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 5: # No bias, no add_out + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=0, + num_groups=num_groups, + ) + + return None + + +def create_group_norm_graph(config: GroupNormConfig) -> bytes: + inputs = ["input", "gamma", "beta"] + outputs = ["output"] + op_type = "GroupNorm" + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + op_type = "SkipGroupNorm" + inputs = [*inputs, "skip"] + if config.has_bias: + inputs = [*inputs, "bias"] + if config.has_add_out: + outputs = [*outputs, "add_out"] + + nodes = [ + helper.make_node( + op_type, + inputs, + outputs, + op_type + "_0", + activation=int(config.activation), + channels_last=int(config.channels_last), + epsilon=config.epsilon, + groups=config.num_groups, + domain="com.microsoft", + ), + ] + + float_type = TensorProto.FLOAT16 if config.fp16 else TensorProto.FLOAT + + input_shapes = [ + helper.make_tensor_value_info("input", float_type, ["N", "H", "W", "C"]), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, ["C"]), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, ["C"]), + ] + output_shapes = [ + helper.make_tensor_value_info("output", float_type, ["N", "H", "W", "C"]), + ] + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + input_shapes = [ + *input_shapes, + helper.make_tensor_value_info("skip", float_type, config.get_skip_symbolic_shape()), + ] + if config.has_bias: + input_shapes = [*input_shapes, helper.make_tensor_value_info("bias", float_type, ["C"])] + if config.has_add_out: + output_shapes = [*output_shapes, helper.make_tensor_value_info("add_out", float_type, ["N", "H", "W", "C"])] + + graph = helper.make_graph( + nodes, + "Group_Norm_Graph", + input_shapes, + output_shapes, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def group_norm_ort( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, + measure_latency=False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]: + onnx_model_str = create_group_norm_graph(config) + ort_session = InferenceSession(onnx_model_str, providers=["CUDAExecutionProvider"]) + + session = CudaSession(ort_session, device=torch.device("cuda:0")) + + io_shape = { + "input": [config.batch_size, config.height, config.width, config.channels], + "gamma": [config.channels], + "beta": [config.channels], + "output": [config.batch_size, config.height, config.width, config.channels], + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + io_shape["skip"] = config.get_skip_shape() + if config.has_bias: + io_shape["bias"] = [config.channels] + if config.has_add_out: + io_shape["add_out"] = [config.batch_size, config.height, config.width, config.channels] + + session.allocate_buffers(io_shape) + + ort_inputs = { + "input": src, + "gamma": gamma, + "beta": beta, + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + ort_inputs["skip"] = skip + if config.has_bias: + ort_inputs["bias"] = bias + + ort_outputs = session.infer(ort_inputs) + output = ort_outputs["output"] + + add_out = ( + ort_outputs["add_out"] if config.op_type == GroupNormOpType.SKIP_GROUP_NORM and config.has_add_out else None + ) + + if measure_latency: + latency_list = [] + for _ in range(10000): + start_time = perf_counter() + session.infer(ort_inputs) + end_time = perf_counter() + latency_list.append(end_time - start_time) + average_latency = statistics.mean(latency_list) + return output, add_out, average_latency + + return output, add_out, None + + +def group_norm_torch( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + add_out = src + + if skip is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + config.broadcast(skip) + + if bias is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + bias.reshape(1, 1, 1, bias.shape[0]) + + x = add_out + if config.channels_last: + x = add_out.clone().permute(0, 3, 1, 2) # from NHWC to NCHW + + weight = gamma.to(x.dtype) + bias = beta.to(x.dtype) + output = torch.nn.functional.group_norm(x, config.num_groups, weight=weight, bias=bias, eps=config.epsilon) + + if config.activation: + torch.nn.functional.silu(output, inplace=True) + + if config.channels_last: + output = output.permute(0, 2, 3, 1) # from NCHW to NHWC + + return output, add_out + + +def print_tensor(name, tensor): + # Print in the format that could be directly added to unit tests in C++. + torch.set_printoptions(precision=6, sci_mode=False, linewidth=100, profile="full", threshold=1000) + print(name) + if tensor is not None: + print("shape", tensor.shape) + text = str(tensor.clone().flatten()) + print(text.replace("[", "[\n").replace("]", ",\n]").replace(",", "f,")) + else: + print(tensor) + + +def run_parity(config, measure_latency=True, verbose=False): + float_type = torch.float16 if config.fp16 else torch.float32 + + input_tensor = torch.randn( + config.batch_size, + config.height, + config.width, + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + gamma = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + beta = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + skip = None + bias = None + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + skip = torch.randn( + *config.get_skip_shape(), + device="cuda", + dtype=float_type, + requires_grad=False, + ) + if config.has_bias: + bias = torch.randn( + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + if verbose: + print(config) + print_tensor("input", input_tensor) + print_tensor("gamma", gamma) + print_tensor("beta", beta) + print_tensor("skip", skip) + print_tensor("bias", bias) + + out_ort, ort_add_out, latency = group_norm_ort( + input_tensor, gamma, beta, skip, bias, config, measure_latency=measure_latency + ) + + if verbose: + print_tensor("out_ort", out_ort) + print_tensor("ort_add_out", ort_add_out) + + torch_out, torch_add_out = group_norm_torch(input_tensor, gamma, beta, skip, bias, config) + + if verbose: + print_tensor("torch_out", torch_out) + print_tensor("torch_add_out", torch_add_out) + + average_diff = numpy.mean(numpy.abs(out_ort.detach().cpu().numpy() - torch_out.detach().cpu().numpy())) + + is_close = numpy.allclose( + out_ort.detach().cpu().numpy(), + torch_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + + is_add_out_close = ( + numpy.allclose( + ort_add_out.detach().cpu().numpy(), + torch_add_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + if ort_add_out is not None + else "" + ) + + # Compare results + print( + config.op_type.name, + " B:", + config.batch_size, + " H:", + config.height, + " W:", + config.width, + " C:", + config.channels, + " G:", + config.num_groups, + " activation:", + int(config.activation), + " channels_last:", + int(config.channels_last), + " fp16:", + int(config.fp16), + f" Latency(μs): {int(latency * 1e6)}" if isinstance(latency, float) else "", + " AvgDiff:", + average_diff, + " Pass:", + is_close, + is_add_out_close, + ) + + +def get_latent_height_width(): + default_size = [(512, 512), (768, 768), (1024, 1024)] + small_img_size = [(512, 768), (768, 512)] + xl_img_size = [ + (1152, 896), + (896, 1152), + (1216, 832), + (832, 1216), + (1344, 768), + (768, 1344), + (1536, 640), + (640, 1536), + ] + return [(int(h / 8), int(w / 8)) for (h, w) in default_size + small_img_size + xl_img_size] + + +def get_channels(): + return [128, 256, 512, 1024, 2048, 320, 640, 960, 1920, 2560, 384, 768, 1536, 3072, 1152, 2304] + + +def run_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm with Silu Activation for ", "fp16" if fp16 else "fp32") + for b in [2]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, activation=True, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_no_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm without Activation for ", "fp16" if fp16 else "fp32") + for b in [1, 2, 4]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_all_groups(template: int, fp16, measure_latency=False): + group_sizes = [1, 2, 4, 8, 16, 32] + print("Test GroupNorm for different group sizes:", group_sizes) + for group_size in group_sizes: + for h, w in get_latent_height_width()[:3]: + for c in get_channels()[:2]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=group_size, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_odd_channels(template: int, fp16, measure_latency=False): + # Test some random number of channels that can be divisible by 2 * num_groups + for h, w in get_latent_height_width(): + for c in [448, 704, 832, 1664, 2240, 2688, 2880, 3008]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_small_inputs(template: int, fp16): + config = GroupNormConfig.create(2, 2, 2, 16, fp16=fp16, activation=False, num_groups=4, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=False, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=True, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + +def run_performance(fp16): + # Run perf test to tune parameters for given number of channels. + for h, w in get_latent_height_width()[:3]: + for c in get_channels(): + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=0) + run_parity(config, measure_latency=True) + + +def run_all(template: int): + for fp16 in [True, False]: + run_small_inputs(template, fp16) + run_odd_channels(template, fp16) + run_all_groups(template, fp16) + run_activation(template, fp16) + run_no_activation(template, fp16) + + +def run_not_implemented(): + # Expect failure. Check whether the error message is expected. + try: + config = GroupNormConfig(1, 2, 2, 513, num_groups=3) + run_parity(config) + except RuntimeError as e: + assert "GroupNorm in CUDA does not support the input: n=1 h=2 w=2 c=513 groups=3" in str(e) + + +def main(): + run_performance(True) + + run_not_implemented() + + for template in range(6): + run_all(template) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ba282193c5ca6..33d50f90333cf 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3323,6 +3323,22 @@ TEST(LiteCustomOpTest, CustomFunc) { ASSERT_TRUE(floats_output[1] == 16); } +TEST(LiteCustomOpTest, CustomFuncOpsetMismatch) { + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + session_options.SetLogSeverityLevel(0); +#if defined(_WIN32) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); +#elif defined(__APPLE__) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); +#else + session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); +#endif + + EXPECT_THROW(Ort::Session(*ort_env, TSTR("testdata/fuse_select_filter_opset_8.onnx"), session_options), std::exception); +} + struct Merge { Merge(const OrtApi* ort_api, const OrtKernelInfo* info) { int64_t reverse; diff --git a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc index ad99b675c7d20..85edfa0e59f1d 100644 --- a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc @@ -94,23 +94,28 @@ void Select(const Ort::Custom::Span& indices_in, } } -void Filter(const Ort::Custom::Tensor& floats_in, - Ort::Custom::Tensor& floats_out) { - const float* in = floats_in.Data(); - auto in_len = floats_in.NumberOfElement(); +struct Filter { + Filter(const OrtApi*, const OrtKernelInfo*) {} + Ort::Status Compute(const Ort::Custom::Tensor& floats_in, + Ort::Custom::Tensor& floats_out) { + const float* in = floats_in.Data(); + auto in_len = floats_in.NumberOfElement(); + + std::vector filter_floats; + for (int64_t i = 0; i < in_len; ++i) { + if (in[i] > 1.f) { + filter_floats.push_back(in[i]); + } + } - std::vector filter_floats; - for (int64_t i = 0; i < in_len; ++i) { - if (in[i] > 1.f) { - filter_floats.push_back(in[i]); + float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); + for (size_t j = 0; j < filter_floats.size(); ++j) { + out[j] = filter_floats[j]; } - } - float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); - for (size_t j = 0; j < filter_floats.size(); ++j) { - out[j] = filter_floats[j]; + return Ort::Status{nullptr}; } -} +}; void Box(const Ort::Custom::Tensor* float_in_1, const Ort::Custom::Tensor* float_in_2, @@ -293,9 +298,9 @@ void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)}; static const std::unique_ptr c_MulTopOpFloat{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; static const std::unique_ptr c_MulTopOpInt32{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; - static const std::unique_ptr c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse)}; + static const std::unique_ptr c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse, {}, 10, 12)}; static const std::unique_ptr c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)}; - static const std::unique_ptr c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)}; + static const std::unique_ptr c_Filter{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", 15, 17)}; static const std::unique_ptr c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)}; static const std::unique_ptr c_CopyTensorArrayAllVariadic{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayAllVariadic", "CPUExecutionProvider", CopyTensorArrayAllVariadic)}; static const std::unique_ptr c_CopyTensorArrayCombined{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayCombined", "CPUExecutionProvider", CopyTensorArrayCombined)}; @@ -314,7 +319,7 @@ void RegisterOps(Ort::CustomOpDomain& domain) { domain.Add(c_MulTopOpInt32.get()); domain.Add(c_Fuse.get()); domain.Add(c_Select.get()); - domain.Add(c_Fill.get()); + domain.Add(c_Filter.get()); domain.Add(c_Box.get()); domain.Add(c_CopyTensorArrayAllVariadic.get()); domain.Add(c_CopyTensorArrayCombined.get()); diff --git a/onnxruntime/test/testdata/fuse_select_filter.onnx b/onnxruntime/test/testdata/fuse_select_filter.onnx index 15d7dd64788d3..0b881228edb9d 100644 --- a/onnxruntime/test/testdata/fuse_select_filter.onnx +++ b/onnxruntime/test/testdata/fuse_select_filter.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä P vector_1 vector_2 @@ -25,4 +25,5 @@ N ÿÿÿÿÿÿÿÿÿb& vector_filtered  - ÿÿÿÿÿÿÿÿÿB \ No newline at end of file + ÿÿÿÿÿÿÿÿÿB +v2 \ No newline at end of file diff --git a/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx b/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx new file mode 100644 index 0000000000000..3ea27767eb9f5 --- /dev/null +++ b/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx @@ -0,0 +1,29 @@ + :Ä +P +vector_1 +vector_2 +alpha vector_fused fuse_node"Fuse* + fuse_algo :v2 +4 +indicesindices_selected select_node"Select:v2 +N + vector_fused +indices_selectedvector_gathered gather_node"GatherElements +; +vector_gatheredvector_filtered filter_node"Filter:v2graphZ +vector_1 + + ÿÿÿÿÿÿÿÿÿZ +vector_2 + + ÿÿÿÿÿÿÿÿÿZ +alpha + + ÿÿÿÿÿÿÿÿÿZ +indices + + ÿÿÿÿÿÿÿÿÿb& +vector_filtered + + ÿÿÿÿÿÿÿÿÿB +v2 \ No newline at end of file diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 44db7c0078cfc..c552ec3aea72d 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -521,6 +521,10 @@ "test_scan_sum_cpu", // Disabled due to output mismatch with tolerance. "test_scan9_sum_cpu" // Disabled due to output mismatch with tolerance. ], + "current_failing_tests_OPENVINO_NPU_FP16": [ + "^test_prelu_broadcast", + "test_loop11_cpu" + ], "current_failing_tests_OPENVINO_opset18": [ // pending opset 18 support, RUNTIME_EXCEPTION : Encountered unknown exception in Initialize() "^test_center_crop_pad_crop_axes_chw", diff --git a/onnxruntime/test/testdata/transform/transpose_graph_gen.py b/onnxruntime/test/testdata/transform/transpose_graph_gen.py new file mode 100644 index 0000000000000..14f2994a1925d --- /dev/null +++ b/onnxruntime/test/testdata/transform/transpose_graph_gen.py @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import onnx +from onnx import TensorProto, helper + + +def GenerateModel(model_name, valid): # noqa: N802 + nodes = [ + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"], perm=[2, 1, 3, 0]), + helper.make_node("Add", ["transposed_input_0", "input_1"], ["output"]), + ] + + if valid: + inputs = [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [1, 1, 3, 3]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 1, 3, 1]), + ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 1, 3, 1])] + else: + inputs = [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [1, 2, 3, 3]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 2, 3, 1]), + ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 2, 3, 1])] + + graph = helper.make_graph( + nodes, + "TransposeAndAdd", # name + inputs, + outputs, + [], + ) + + model = helper.make_model(graph) + onnx.save(model, model_name) + + +GenerateModel("transpose_to_reshape_valid.onnx", True) +GenerateModel("transpose_to_reshape_invalid.onnx", False) diff --git a/onnxruntime/test/testdata/transform/transpose_to_reshape_invalid.onnx b/onnxruntime/test/testdata/transform/transpose_to_reshape_invalid.onnx new file mode 100644 index 0000000000000..a09b13fc184a8 Binary files /dev/null and b/onnxruntime/test/testdata/transform/transpose_to_reshape_invalid.onnx differ diff --git a/onnxruntime/test/testdata/transform/transpose_to_reshape_valid.onnx b/onnxruntime/test/testdata/transform/transpose_to_reshape_valid.onnx new file mode 100644 index 0000000000000..344d18ac10f77 Binary files /dev/null and b/onnxruntime/test/testdata/transform/transpose_to_reshape_valid.onnx differ diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index c90acfdb7bb78..80d937fa163e6 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4180,6 +4180,7 @@ Return true if all elements are true and false otherwise. .Attr("func_name", "Function name of the Python Triton kernel.", AttributeProto::STRING, std::string("")) .Attr("onnx_key", "The hash key for the ONNX graph.", AttributeProto::INT, static_cast(0)) .Attr("onnx_string", "The onnx string of the triton kernel.", AttributeProto::STRING, std::string("")) + .AllowUncheckedAttributes() .Input(0, "inputs", "Input tensors. If to call an existing Python Triton kernel, " "the input count and order should match the arguments of the function. If to compute an ONNX graph, " diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index e5c65b2a96d8c..57d76577f1ba7 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -63,6 +63,7 @@ #include "orttraining/core/optimizer/scaled_sum_fusion.h" #include "orttraining/core/optimizer/shape_optimizer.h" #include "orttraining/core/optimizer/transformer_layer_recompute.h" +#include "orttraining/core/optimizer/transpose_replacement.h" #include "core/optimizer/compute_optimizer/upstream_gather.h" #include "core/optimizer/compute_optimizer/upstream_reshape.h" #include "core/optimizer/pre_shape_node_elimination.h" @@ -203,6 +204,7 @@ std::vector> GeneratePreTrainingTransformers( std::make_unique(optimizer_utils::GenerateRuleBasedTransformerName(level), compatible_eps); ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); + ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); } break; case TransformerLevel::Level3: { diff --git a/orttraining/orttraining/core/optimizer/transpose_replacement..cc b/orttraining/orttraining/core/optimizer/transpose_replacement..cc new file mode 100644 index 0000000000000..48e9c4d6e6a07 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/transpose_replacement..cc @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/transpose_replacement.h" + +#include "core/common/logging/logging.h" +#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +Status TransposeReplacement::Apply(Graph& graph, + Node& transpose_node, + RewriteRuleEffect& rule_effect, + const logging::Logger& logger) const { + auto& transpose_inputs = transpose_node.MutableInputDefs(); + auto& transpose_outputs = transpose_node.MutableOutputDefs(); + NodeArg* input = transpose_inputs[0]; + auto input_shape = input->Shape(); + if (!input_shape) { + LOG_DEBUG_INFO(logger, "Exit TransposeReplacement optimization for input shape is None."); + return Status::OK(); + } + auto perm = graph_utils::onnx_repeated_values::RetrieveValues(transpose_node.GetAttributes().at("perm")); + InlinedVector new_shape; + new_shape.reserve(perm.size()); + int64_t last_permuted_axis = 0; + for (int i = 0; i < static_cast(perm.size()); ++i) { + if (!input_shape->dim(static_cast(perm[i])).has_dim_value()) { + LOG_DEBUG_INFO(logger, "Exit TransposeReplacement optimization for not supporting symbolic shape."); + return Status::OK(); + } + new_shape.push_back(input_shape->dim(static_cast(perm[i])).dim_value()); + if (input_shape->dim(static_cast(perm[i])).dim_value() == 1) + continue; + if (perm[i] < last_permuted_axis) { + LOG_DEBUG_INFO(logger, "Exit TransposeReplacement optimization for not supporting shape."); + return Status::OK(); + } + last_permuted_axis = perm[i]; + } + + transpose_inputs.push_back( + optimizer::compute_optimizer::CreateInitializerFromVector(graph, + {static_cast(new_shape.size())}, + new_shape, + graph.GenerateNodeArgName("transpose_reshape_shape"))); + + Node& transpose_reshape_node = graph.AddNode(graph.GenerateNodeName("Transpose_Reshape"), + "Reshape", + "Transpose replaced Reshape", + transpose_inputs, + transpose_outputs, + nullptr, + kOnnxDomain); + transpose_reshape_node.SetExecutionProviderType(transpose_node.GetExecutionProviderType()); + graph_utils::FinalizeNodeFusion(graph, transpose_reshape_node, transpose_node); + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} + +bool TransposeReplacement::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const { + return true; +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/transpose_replacement.h b/orttraining/orttraining/core/optimizer/transpose_replacement.h new file mode 100644 index 0000000000000..c38e402339823 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/transpose_replacement.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/compute_optimizer/shared_utils.h" + +namespace onnxruntime { + +/** +@Class TransposeReplacement + +Transpose is equivalent to a Reshape if: + empty dimensions (which dim_value=1) can change place, not empty dimensions must be in + the same order in the permuted tenosr. + Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). + +This Rewrite rule replaces Transpose which meets the requirments with Reshape. +Because Transpose need memory copy while Reshape needn't, this replacement can save overhead for memory copy. + +It is attempted to be triggered only on nodes with op type "Transpose". +*/ +class TransposeReplacement : public RewriteRule { + public: + TransposeReplacement() noexcept : RewriteRule("TransposeReplacement") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Transpose"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py index 97318ea2e53ae..c1b99e4859dbd 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py @@ -3,15 +3,28 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out -from ._slice_scel import slice_scel, slice_scel_backward, transform_slice_scel +import os -__all__ = [ +from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401 +from ._slice_scel import optimize_graph_for_slice_scel, slice_scel, slice_scel_backward # noqa: F401 + +_all_kernels = [ "triton_gemm", "triton_gemm_out", "triton_matmul", "triton_matmul_out", "slice_scel", "slice_scel_backward", - "transform_slice_scel", ] + +_all_optimizers = [ + "optimize_graph_for_slice_scel", +] + +if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1: + from ._flash_attn import flash_attn_backward, flash_attn_forward, optimize_graph_for_flash_attention # noqa: F401 + + _all_kernels.extend(["flash_attn_forward", "flash_attn_backward"]) + _all_optimizers.append("optimize_graph_for_flash_attention") + +__all__ = _all_kernels + _all_optimizers # noqa: PLE0605 diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py new file mode 100644 index 0000000000000..03bb0f4373d8d --- /dev/null +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py @@ -0,0 +1,1244 @@ +""" +*Experimental* implementation of FlashAttention in Triton. +Tested with triton==2.0.0.dev20221202. +Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions +other than 64: +https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 +We'll update this implementation with the new Triton backend once this is fixed. + +We use the FlashAttention implementation from Phil Tillet a starting point. +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Implement both causal and non-causal attention. +- Implement both self-attention and cross-attention. +- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. +- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. +- Support attention bias. +- Speed up the forward pass a bit, and only store the LSE instead of m and l. +- Make the backward for d=128 much faster by reducing register spilling. +- Optionally parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. + +Caution: +- This is an *experimental* implementation. The forward pass should be quite robust but +I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). +- This implementation has only been tested on A100. +- If you plan to use headdim other than 64 and 128, you should test for race conditions +(due to the Triton compiler), as done in tests/test_flash_attn.py +"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions +for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident +that there are none left for other head dimensions. + +Differences between this Triton version and the CUDA version: +- Triton version doesn't support dropout. +- Triton forward is generally faster than CUDA forward, while Triton backward is +generally slower than CUDA backward. Overall Triton forward + backward is slightly slower +than CUDA forward + backward. +- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton version supports attention bias, while CUDA version doesn't. +""" + +import math +from typing import List, Tuple + +import torch +import triton +import triton.language as tl +from onnx import GraphProto, NodeProto, TensorProto, helper + +from onnxruntime.training.ortmodule import register_graph_optimizer +from onnxruntime.training.ortmodule.graph_optimizers.utils import GraphMatcher, check_attribute_value, update_graph + + +# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), +# # This config has a race condition when EVEN_M == False, disabling it for now. +# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), +# ], +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# ) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Bias, + Out, + Lse, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_ob, + stride_oh, + stride_om, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? + # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + if BIAS_TYPE == "vector": + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! + if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if BIAS_TYPE != "none": + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. + qk = qk * softmax_scale + bias + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + seqlen_q_rounded, + headdim, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load( + Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + do = tl.load( + DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD: tl.constexpr, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + if BIAS_TYPE == "vector": + b_ptrs = Bias + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop + # may have zero step, and pipelining with the bias matrix could cause the problem. + # So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) + v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + if BIAS_TYPE != "none": + tl.debug_barrier() # Race condition otherwise + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load( + b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + qk = qk * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + if BIAS_TYPE == "none": + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + else: + p = tl.exp(qk - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. + # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix' + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + if BIAS_TYPE == "matrix": + b_ptrs += BLOCK_M * stride_bm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + ], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dvb, + stride_dvh, + stride_dvn, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + if BIAS_TYPE != "none": + Bias += off_b * stride_bb + off_h * stride_bh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=False, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=True, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + +def flash_attn_forward(q, k, v, bias=None, **kwargs): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, "FlashAttention only support head dimensions up to 128" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert q.is_cuda and k.is_cuda and v.is_cuda + + causal = kwargs.get("causal", 0) == 1 + softmax_scale = kwargs.get("softmax_scale", 1.0 / math.sqrt(d)) + has_bias = bias is not None + bias_type = "none" + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + if bias.stride(-1) != 1: + bias = bias.contiguous() + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)") + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, + k, + v, + bias, + o, + lse, + tmp, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + o.stride(0), + o.stride(2), + o.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse + + +def flash_attn_backward(do, q, k, v, o, lse, bias=None, **kwargs): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + + causal = kwargs.get("causal", 0) == 1 + softmax_scale = kwargs.get("softmax_scale", 1.0 / math.sqrt(d)) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, + do, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + nheads, + seqlen_q, + seqlen_q_rounded, + d, + BLOCK_M=128, + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + has_bias = bias is not None + bias_type = "none" + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)") + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: ( + triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads, + ) + _bwd_kernel[grid]( + q, + k, + v, + bias, + do, + dq_accum, + dk, + dv, + lse, + delta, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + do.stride(0), + do.stride(2), + do.stride(1), + dq_accum.stride(0), + dq_accum.stride(2), + dq_accum.stride(1), + dk.stride(0), + dk.stride(2), + dk.stride(1), + dv.stride(0), + dv.stride(2), + dv.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + return dq, dk, dv + + +def _make_flash_attention_nodes( + idx: int, + q: str, + k: str, + v: str, + y: str, + dy: str, + dq: str, + dk: str, + dv: str, + bias: str, + scale: float, +): + logsumexp = helper.make_tensor_value_info("logsumexp_" + str(idx), TensorProto.FLOAT, []) + fwd_node = helper.make_node( + "TritonOp", + [q, k, v, bias], + [y, logsumexp.name], + "TritonOp_Flash_Attn_Fwd_" + str(idx), + None, + "com.microsoft", + func_name="flash_attn_forward", + causal=0, + softmax_scale=scale, + ) + bwd_node = helper.make_node( + "TritonOp", + [dy, q, k, v, y, logsumexp.name, bias], + [dq, dk, dv], + "TritonOp_Flash_Attn_Bwd_" + str(idx), + None, + "com.microsoft", + func_name="flash_attn_backward", + causal=0, + softmax_scale=scale, + ) + return [fwd_node, bwd_node], [logsumexp] + + +# Without causal mask, without Dropout. For example, BERT model in HuggingFace. +_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 0)]), # 1 + ("Transpose", True, [(0, 0, 1)]), # 2 + ("Div", False, [(0, 0, 0)]), # 3 + ("Add", False, [(3, 0, 0)]), # 4 + ("Softmax", False, [(4, 0, 0)]), # 5 + ("MatMul", False, [(5, 0, 0)]), # 6 + ("Transpose", True, [(6, 0, 1)]), # 7 + ("Transpose", False, [(6, 0, 0)]), # 8 + ("FusedMatMul", False, [(7, 0, 1)]), # 9 + ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 + ("Identity", False, [(10, 0, 0)]), # 11 + ("Div", False, [(11, 0, 0)]), # 12 + ("Identity", False, [(12, 0, 0)]), # 13 + ("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15 + ("FusedMatMul", False, [(5, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Transpose", False, [(14, 0, 0)]), # 18 + ("Transpose", False, [(15, 0, 0)]), # 19 + ("Transpose", False, [(16, 0, 0)]), # 20 +] + + +def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[3].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1]) + and scale_value is not None + and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) + ): + return [], [], [] + + nodes_to_add, new_value_infos = _make_flash_attention_nodes( + idx, + nodes[1].input[0], + nodes[2].input[0], + nodes[7].input[0], + nodes[8].output[0], + nodes[17].input[0], + nodes[18].output[0], + nodes[19].output[0], + nodes[20].output[0], + nodes[4].input[1], + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ) + return nodes, nodes_to_add, new_value_infos + + +# llama2+peft, k doesn't require grad. +_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 1)]), # 1 + ("Div", False, [(0, 0, 0)]), # 2 + ("Add", False, [(2, 0, 0)]), # 3 + ("Softmax", False, [(3, 0, 0)]), # 4 + ("MatMul", False, [(4, 0, 0)]), # 5 + ("Transpose", True, [(5, 0, 1)]), # 6 + ("Identity", False, [(6, 0, 0)]), # 7 + ("YieldOp", False, [(7, 0, -1)]), # 8 + ("Transpose", False, [(5, 0, 0)]), # 9 + ("FusedMatMul", False, [(6, 0, 1)]), # 10 + ("SoftmaxGrad_13", False, [(10, 0, 0), (4, 0, 1)]), # 11 + ("Identity", False, [(11, 0, 0)]), # 12 + ("Div", False, [(12, 0, 0)]), # 13 + ("Identity", False, [(13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 1), (14, 0, 0)]), # 15 + ("FusedMatMul", False, [(4, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Sum", False, [(16, 0, 0)]), # 18 + ("Transpose", False, [(18, 0, 0)]), # 19 +] + + +def _optimize_for_pattern_1(matcher: GraphProto, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[2].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 1, 3, 2]) + and scale_value is not None + and check_attribute_value(nodes[6], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3]) + and matcher.get_consumer_count(nodes[14].output[0]) == 1 + ): + return [], [], [] + + dtype, _ = matcher.get_type_and_shape(nodes[0].input[0]) + assert dtype is not None + trans_q_tensor = helper.make_tensor_value_info("trans_q_" + str(idx), dtype, None) + trans_q_grad_tensor = helper.make_tensor_value_info("trans_q_grad_" + str(idx), dtype, None) + trans_k_tensor = helper.make_tensor_value_info("trans_k_" + str(idx), dtype, None) + trans_q = helper.make_node( + "Transpose", [nodes[0].input[0]], [trans_q_tensor.name], "Trans_Q_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_q_grad = helper.make_node( + "Transpose", [trans_q_grad_tensor.name], [nodes[15].output[0]], "Trans_Q_Grad_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_k = helper.make_node( + "Transpose", [nodes[1].input[0]], [trans_k_tensor.name], "Trans_K_" + str(idx), perm=[0, 2, 1, 3] + ) + nodes[19].input[0] = nodes[18].input[1] + v_grad = nodes[19].output[0] + nodes[19].output[0] = nodes[18].output[0] + nodes[18].input[1] = nodes[18].output[0] + nodes[18].output[0] = v_grad + nodes_to_add, new_value_infos = _make_flash_attention_nodes( + idx, + trans_q_tensor.name, + trans_k_tensor.name, + nodes[6].input[0], + nodes[9].output[0], + nodes[17].input[0], + trans_q_grad_tensor.name, + "", + nodes[16].output[0], + nodes[3].input[1], + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ) + nodes_to_remove = nodes[:6] + nodes[9:18] + nodes_to_add.extend([trans_q, trans_q_grad, trans_k]) + new_value_infos.extend([trans_q_tensor, trans_q_grad_tensor, trans_k_tensor]) + return nodes_to_remove, nodes_to_add, new_value_infos + + +# llama2+peft, k requires grad. +_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 1)]), # 1 + ("Div", False, [(0, 0, 0)]), # 2 + ("Add", False, [(2, 0, 0)]), # 3 + ("Softmax", False, [(3, 0, 0)]), # 4 + ("MatMul", False, [(4, 0, 0)]), # 5 + ("Transpose", True, [(5, 0, 1)]), # 6 + ("Identity", False, [(6, 0, 0)]), # 7 + ("YieldOp", False, [(7, 0, -1)]), # 8 + ("Transpose", False, [(5, 0, 0)]), # 9 + ("FusedMatMul", False, [(6, 0, 1)]), # 10 + ("SoftmaxGrad_13", False, [(10, 0, 0), (4, 0, 1)]), # 11 + ("Identity", False, [(11, 0, 0)]), # 12 + ("Div", False, [(12, 0, 0)]), # 13 + ("Identity", False, [(13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 1), (14, 0, 0)]), # 15 + ("FusedMatMul", False, [(14, 0, 1)]), # 16 + ("Transpose", False, [(16, 0, 0)]), # 17 + ("FusedMatMul", False, [(4, 0, 0)]), # 18 + ("Transpose", True, [(18, 0, 1)]), # 19 + ("Sum", False, [(18, 0, 0)]), # 20 + ("Transpose", False, [(20, 0, 0)]), # 21 +] + + +def _aptimize_for_pattern_2(matcher: GraphProto, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[2].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 1, 3, 2]) + and scale_value is not None + and check_attribute_value(nodes[6], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3]) + and matcher.get_consumer_count(nodes[14].output[0]) == 2 + ): + return [], [], [] + + dtype, _ = matcher.get_type_and_shape(nodes[0].input[0]) + assert dtype is not None + trans_q_tensor = helper.make_tensor_value_info("trans_q_" + str(idx), dtype, None) + trans_q_grad_tensor = helper.make_tensor_value_info("trans_q_grad_" + str(idx), dtype, None) + trans_k_tensor = helper.make_tensor_value_info("trans_k_" + str(idx), dtype, None) + trans_k_grad_tensor = helper.make_tensor_value_info("trans_k_grad_" + str(idx), dtype, None) + trans_q = helper.make_node( + "Transpose", [nodes[0].input[0]], [trans_q_tensor.name], "Trans_Q_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_q_grad = helper.make_node( + "Transpose", [trans_q_grad_tensor.name], [nodes[15].output[0]], "Trans_Q_Grad_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_k = helper.make_node( + "Transpose", [nodes[1].input[0]], [trans_k_tensor.name], "Trans_K_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_k_grad = helper.make_node( + "Transpose", [trans_k_grad_tensor.name], [nodes[17].output[0]], "Trans_K_Grad_" + str(idx), perm=[0, 2, 1, 3] + ) + nodes[21].input[0] = nodes[20].input[1] + v_grad = nodes[21].output[0] + nodes[21].output[0] = nodes[20].output[0] + nodes[20].input[1] = nodes[20].output[0] + nodes[20].output[0] = v_grad + nodes_to_add, new_value_infos = _make_flash_attention_nodes( + idx, + trans_q_tensor.name, + trans_k_tensor.name, + nodes[6].input[0], + nodes[9].output[0], + nodes[19].input[0], + trans_q_grad_tensor.name, + trans_k_grad_tensor.name, + nodes[18].output[0], + nodes[3].input[1], + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ) + nodes_to_remove = nodes[:6] + nodes[9:20] + nodes_to_add.extend([trans_q, trans_q_grad, trans_k, trans_k_grad]) + new_value_infos.extend([trans_q_tensor, trans_q_grad_tensor, trans_k_tensor, trans_k_grad_tensor]) + return nodes_to_remove, nodes_to_add, new_value_infos + + +# TODO: add pattern to support attention with causal mask, such as GPT2 in HuggingFace. +_PATTERNS = [ + (_PATTERN_0, _optimize_for_pattern_0), + (_PATTERN_1, _optimize_for_pattern_1), + (_PATTERN_2, _aptimize_for_pattern_2), +] + + +@register_graph_optimizer(devices="cuda") +def optimize_graph_for_flash_attention(graph: GraphProto): + nodes_to_remove = [] + nodes_to_add = [] + new_value_infos = [] + matcher = GraphMatcher(graph) + idx = 0 + for pattern_tuple in _PATTERNS: + for nodes in matcher.match_pattern(pattern_tuple[0]): + remove_nodes, add_nodes, add_value_infos = pattern_tuple[1](matcher, idx, nodes) + if len(add_nodes) > 0: + nodes_to_remove.extend(remove_nodes) + nodes_to_add.extend(add_nodes) + new_value_infos.extend(add_value_infos) + idx += 1 + update_graph(graph, nodes_to_remove, nodes_to_add, new_value_infos) diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py b/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py index 8edcc9b63ef4f..fb7ddc68900c9 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py @@ -11,7 +11,7 @@ import triton.language as tl from onnx import TensorProto, helper -from onnxruntime.training.ortmodule import register_graph_transformer +from onnxruntime.training.ortmodule import register_graph_optimizer from .._utils import get_attribute, to_numpy_array @@ -246,8 +246,8 @@ def _get_shape_related_nodes(graph, start_arg, sub_graph_nodes): args.append(output) -@register_graph_transformer(devices="cuda") -def transform_slice_scel(graph): +@register_graph_optimizer(devices="cuda") +def optimize_graph_for_slice_scel(graph): remove_nodes = [] triton_nodes = [] value_infos = [] diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index 59cf05bb082fc..fbf1b7c2bac42 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -124,7 +124,8 @@ def _are_deterministic_algorithms_enabled(): return ORTMODULE_IS_DETERMINISTIC -from .graph_transformer_registry import register_graph_transformer # noqa: E402, F401 +from .graph_optimizer_registry import register_graph_optimizer # noqa: E402, F401 +from .graph_optimizers import * # noqa: E402, F403 from .options import DebugOptions, LogLevel # noqa: E402, F401 # ORTModule must be loaded only after all validation passes diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 3953d342f1897..e0f11e5aa407e 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -21,7 +21,7 @@ from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results -from .graph_transformer_registry import GraphTransformerRegistry +from .graph_optimizer_registry import GraphOptimizerRegistry from .options import DebugOptions, _SkipCheck @@ -369,7 +369,7 @@ def _build_graph(self, graph_transformer_config): device_type = self._device.type if device_type == "cuda" and self.is_rocm_pytorch: device_type = "rocm" - GraphTransformerRegistry.transform_all( + GraphOptimizerRegistry.optimize_all( type(self._flattened_module._original_module).__name__, device_type, self._onnx_models.optimized_model.graph ) diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py new file mode 100644 index 0000000000000..897ecac148bfb --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from typing import Callable + +from onnx.onnx_ml_pb2 import GraphProto + + +class GraphOptimizerRegistry: + _OPTIMIZER_FUNCS = {} # noqa: RUF012 + + @classmethod + def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]): + modules = [] + if target_modules == "all": + modules.append("all") + else: + modules = target_modules.split("|") + for module in modules: + if module in cls._OPTIMIZER_FUNCS: + cls._OPTIMIZER_FUNCS[module].append((fn, devices, priority)) + else: + cls._OPTIMIZER_FUNCS[module] = [(fn, devices, priority)] + + @classmethod + def optimize_all(cls, module_name: str, device: str, graph: GraphProto): + optimizers_to_apply = [] + if "all" in cls._OPTIMIZER_FUNCS: + optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS["all"]) + if module_name in cls._OPTIMIZER_FUNCS: + optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS[module_name]) + optimizers_to_apply = [x for x in optimizers_to_apply if x[1] == "all" or device in x[1]] + optimizers_to_apply.sort(key=lambda x: x[2], reverse=True) + for fn, _, _ in optimizers_to_apply: + fn(graph) + + +# target_modules can be multiple module names separated by "|", or "all" means apply to all modules. +# devices can be multiple device types separated by "|" or "all" means apply to all devices. +def register_graph_optimizer(target_modules: str = "all", devices: str = "all", priority: int = 0): + def graph_optimizer_wrapper(fn): + GraphOptimizerRegistry.register(target_modules, devices, priority, fn) + return fn + + return graph_optimizer_wrapper diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py new file mode 100644 index 0000000000000..d215e12f8137a --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py @@ -0,0 +1,15 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import os + +_all_optimizers = [] + +if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1: + from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401 + + _all_optimizers.append("optimize_graph_for_aten_efficient_attention") + +__all__ = _all_optimizers # noqa: PLE0605 diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py new file mode 100644 index 0000000000000..94bd41293b427 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -0,0 +1,414 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +PyTorch's _efficient_attention_forward/_efficient_attention_backward APIs is keep changing. Current implementation +is tested well on version 2.2.0.dev20231010+cu121, and should be run well since official version 2.2.0. If may fail to +run is you are using PyTorch with older versions. + +PyTorch also has API for flash attention (currently doesn't support random attention mask or Dropout), we can add +support if we want to try in the future. +""" + +from typing import List, Tuple + +from onnx import GraphProto, NodeProto, TensorProto, helper + +from ..graph_optimizer_registry import register_graph_optimizer +from .utils import GraphMatcher, check_attribute_value, make_constant_node, update_graph + + +def _make_efficient_attention_nodes( + idx: int, + q: str, + k: str, + v: str, + y: str, + dy: str, + dq: str, + dk: str, + dv: str, + bias: str, + expand_bias: bool, + scale: float, + dropout_ratio: float, + causal: bool, +): + nodes_to_add = [] + scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale]) + dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio]) + causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0]) + int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0]) + true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True]) + false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False]) + logsumexp = helper.make_tensor_value_info("logsumexp" + str(idx), TensorProto.FLOAT, []) + seed = helper.make_tensor_value_info("seed" + str(idx), TensorProto.INT64, []) + offset = helper.make_tensor_value_info("offset" + str(idx), TensorProto.INT64, []) + new_value_infos = [logsumexp, seed, offset] + if expand_bias: + shape_0 = helper.make_node("Shape", [q], ["shape_0_" + str(idx)], start=0, end=1) + shape_1 = helper.make_node("Shape", [q], ["shape_1_" + str(idx)], start=2, end=3) + shape_2 = helper.make_node("Shape", [q], ["shape_2_" + str(idx)], start=1, end=2) + shape_3 = helper.make_node("Shape", [k], ["shape_3_" + str(idx)], start=1, end=2) + concat = helper.make_node( + "Concat", + ["shape_0_" + str(idx), "shape_1_" + str(idx), "shape_2_" + str(idx), "shape_3_" + str(idx)], + ["concated_shape_" + str(idx)], + axis=0, + ) + expand = helper.make_node("Expand", [bias, "concated_shape_" + str(idx)], ["expanded_bias_" + str(idx)]) + nodes_to_add.extend([shape_0, shape_1, shape_2, shape_3, concat, expand]) + bias = "expanded_bias_" + str(idx) + fwd_node = helper.make_node( + "ATen", + [ + q, + k, + v, + bias, + "", + "", + "", + dropout_ratio_node.output[0], + causal_node.output[0], + true_node.output[0], + scale_node.output[0], + "", + "", + ], + [y, logsumexp.name, seed.name, offset.name], + "efficient_attention_forward_" + str(idx), + None, + "org.pytorch.aten", + operator="_efficient_attention_forward", + ) + bwd_node = helper.make_node( + "ATen", + [ + dy, + q, + k, + v, + bias, + y, + "", + "", + int_zero_node.output[0], + int_zero_node.output[0], + logsumexp.name, + dropout_ratio_node.output[0], + seed.name, + offset.name, + causal_node.output[0], + false_node.output[0], + scale_node.output[0], + "", + ], + [dq, dk, dv, ""], + "efficient_attention_backward_" + str(idx), + None, + "org.pytorch.aten", + operator="_efficient_attention_backward", + ) + nodes_to_add.extend( + [scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node] + ) + return nodes_to_add, new_value_infos + + +# Without causal mask, with Dropout. For example, BERT model in HuggingFace. +_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 0)]), # 1 + ("Transpose", True, [(0, 0, 1)]), # 2 + ("Div", False, [(0, 0, 0)]), # 3 + ("Add", False, [(3, 0, 0)]), # 4 + ("Softmax", False, [(4, 0, 0)]), # 5 + ("Dropout", False, [(5, 0, 0)]), # 6 + ("MatMul", False, [(6, 0, 0)]), # 7 + ("Transpose", True, [(7, 0, 1)]), # 8 + ("Transpose", False, [(7, 0, 0)]), # 9 + ("FusedMatMul", False, [(8, 0, 1)]), # 10 + ("DropoutGrad", False, [(10, 0, 0), (6, 1, 1)]), # 11 + ("SoftmaxGrad_13", False, [(11, 0, 0), (5, 0, 1)]), # 12 + ("Identity", False, [(12, 0, 0)]), # 13 + ("Div", False, [(13, 0, 0)]), # 14 + ("Identity", False, [(14, 0, 0)]), # 15 + ("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)]), # 16 + ("FusedMatMul", False, [(1, 0, 0), (15, 0, 1)]), # 17 + ("FusedMatMul", False, [(6, 0, 0)]), # 18 + ("Transpose", True, [(18, 0, 1)]), # 19 + ("Transpose", False, [(16, 0, 0)]), # 20 + ("Transpose", False, [(17, 0, 0)]), # 21 + ("Transpose", False, [(18, 0, 0)]), # 22 +] + + +def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[3].input[1]) + ratio_value = matcher.get_constant_value(nodes[6].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1]) + and scale_value is not None + and ratio_value is not None + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3]) + ): + return [], [], [] + + _, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0]) + _, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1]) + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[1].input[0], + nodes[2].input[0], + nodes[8].input[0], + nodes[9].output[0], + nodes[19].input[0], + nodes[20].output[0], + nodes[21].output[0], + nodes[22].output[0], + nodes[4].input[1], + add_input_shape_0 != add_input_shape_1, + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ratio_value, + False, + ) + return nodes, nodes_to_add, new_value_infos + + +# Without causal mask, without Dropout. For example, BERT model and disabling attention dropout in HuggingFace. +_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 0)]), # 1 + ("Transpose", True, [(0, 0, 1)]), # 2 + ("Div", False, [(0, 0, 0)]), # 3 + ("Add", False, [(3, 0, 0)]), # 4 + ("Softmax", False, [(4, 0, 0)]), # 5 + ("MatMul", False, [(5, 0, 0)]), # 6 + ("Transpose", True, [(6, 0, 1)]), # 7 + ("Transpose", False, [(6, 0, 0)]), # 8 + ("FusedMatMul", False, [(7, 0, 1)]), # 9 + ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 + ("Identity", False, [(10, 0, 0)]), # 11 + ("Div", False, [(11, 0, 0)]), # 12 + ("Identity", False, [(12, 0, 0)]), # 13 + ("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15 + ("FusedMatMul", False, [(5, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Transpose", False, [(14, 0, 0)]), # 18 + ("Transpose", False, [(15, 0, 0)]), # 19 + ("Transpose", False, [(16, 0, 0)]), # 20 +] + + +def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[3].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1]) + and scale_value is not None + and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) + ): + return [], [], [] + + _, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0]) + _, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1]) + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[1].input[0], + nodes[2].input[0], + nodes[7].input[0], + nodes[8].output[0], + nodes[17].input[0], + nodes[18].output[0], + nodes[19].output[0], + nodes[20].output[0], + nodes[4].input[1], + add_input_shape_0 != add_input_shape_1, + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + 0.0, + False, + ) + return nodes, nodes_to_add, new_value_infos + + +# No causal mask, no attention mask, without Dropout. +_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Mul", True, [(0, 0, 0)]), # 1 + ("Mul", True, [(0, 0, 1)]), # 2 + ("Cast", True, [(1, 0, 0)]), # 3 + ("Cast", True, [(2, 0, 0)]), # 4 + ("Transpose", True, [(3, 0, 0)]), # 5 + ("Transpose", True, [(4, 0, 0)]), # 6 + ("Softmax", False, [(0, 0, 0)]), # 7 + ("Cast", False, [(7, 0, 0)]), # 8 + ("MatMul", False, [(8, 0, 0)]), # 9 + ("Transpose", True, [(9, 0, 1)]), # 10 + ("Transpose", False, [(9, 0, 0)]), # 11 + ("FusedMatMul", False, [(10, 0, 1)]), # 12 + ("Cast", False, [(12, 0, 0)]), # 13 + ("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14 + ("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15 + ("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16 + ("Mul", False, [(15, 0, 0)]), # 17 + ("Mul", False, [(16, 0, 0)]), # 18 + ("Identity", False, [(17, 0, 0)]), # 19 + ("Identity", False, [(18, 0, 0)]), # 20 + ("Cast", False, [(19, 0, 0)]), # 21 + ("Cast", False, [(20, 0, 0)]), # 22 + ("Transpose", False, [(21, 0, 0)]), # 23 + ("Transpose", False, [(22, 0, 0)]), # 24 + ("FusedMatMul", False, [(8, 0, 0)]), # 25 + ("Transpose", True, [(25, 0, 1)]), # 26 + ("Transpose", False, [(25, 0, 0)]), # 27 +] + + +def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) + scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 + scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) + scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 + if not ( + check_attribute_value(nodes[3], "to", 1) + and check_attribute_value(nodes[4], "to", 1) + and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[8], "to", 10) + and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3]) + and scale_value_1 == scale_value_2 + ): + return [], [], [] + + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[5].input[0], + nodes[6].input[0], + nodes[10].input[0], + nodes[11].output[0], + nodes[26].input[0], + nodes[23].output[0], + nodes[24].output[0], + nodes[27].output[0], + "", + False, + scale_value_1, + 0.0, + False, + ) + return nodes, nodes_to_add, new_value_infos + + +# Has causal mask, no attention mask, without Dropout. +_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Mul", True, [(0, 0, 0)]), # 1 + ("Mul", True, [(0, 0, 1)]), # 2 + ("Cast", True, [(1, 0, 0)]), # 3 + ("Cast", True, [(2, 0, 0)]), # 4 + ("Transpose", True, [(3, 0, 0)]), # 5 + ("Transpose", True, [(4, 0, 0)]), # 6 + ("Add", False, [(0, 0, 0)]), # 7 + ("Cast", True, [(7, 0, 1)]), # 8 + ("Slice", True, [(8, 0, 0)]), # 9 + ("Slice", True, [(9, 0, 0)]), # 10 + ("Unsqueeze", True, [(9, 0, 2)]), # 11 + ("Gather", True, [(11, 0, 0)]), # 12 + ("Shape", True, [(12, 0, 0)]), # 13 + ("Softmax", False, [(7, 0, 0)]), # 14 + ("Cast", False, [(14, 0, 0)]), # 15 + ("MatMul", False, [(15, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Transpose", False, [(16, 0, 0)]), # 18 + ("FusedMatMul", False, [(17, 0, 1)]), # 19 + ("Cast", False, [(19, 0, 0)]), # 20 + ("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21 + ("Identity", False, [(21, 0, 0)]), # 22 + ("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23 + ("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24 + ("Mul", False, [(23, 0, 0)]), # 25 + ("Mul", False, [(24, 0, 0)]), # 26 + ("Identity", False, [(25, 0, 0)]), # 27 + ("Identity", False, [(26, 0, 0)]), # 28 + ("Cast", False, [(27, 0, 0)]), # 29 + ("Cast", False, [(28, 0, 0)]), # 30 + ("Transpose", False, [(29, 0, 0)]), # 31 + ("Transpose", False, [(30, 0, 0)]), # 32 + ("FusedMatMul", False, [(15, 0, 0)]), # 33 + ("Transpose", True, [(33, 0, 1)]), # 34 + ("Transpose", False, [(33, 0, 0)]), # 35 +] + + +def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) + scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 + scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) + scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 + if not ( + check_attribute_value(nodes[3], "to", 1) + and check_attribute_value(nodes[4], "to", 1) + and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[15], "to", 10) + and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3]) + and scale_value_1 == scale_value_2 + ): + return [], [], [] + + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[5].input[0], + nodes[6].input[0], + nodes[17].input[0], + nodes[18].output[0], + nodes[34].input[0], + nodes[31].output[0], + nodes[32].output[0], + nodes[35].output[0], + "", + False, + scale_value_1, + 0.0, + True, + ) + return nodes, nodes_to_add, new_value_infos + + +_PATTERNS = [ + (_PATTERN_0, _optimize_for_pattern_0), + (_PATTERN_1, _optimize_for_pattern_1), + (_PATTERN_2, _optimize_for_pattern_2), + (_PATTERN_3, _optimize_for_pattern_3), +] + + +@register_graph_optimizer(devices="cuda") +def optimize_graph_for_aten_efficient_attention(graph: GraphProto): + nodes_to_remove = [] + nodes_to_add = [] + new_value_infos = [] + matcher = GraphMatcher(graph) + idx = 0 + for pattern_tuple in _PATTERNS: + for nodes in matcher.match_pattern(pattern_tuple[0]): + remove_nodes, add_nodes, add_value_infos = pattern_tuple[1](matcher, idx, nodes) + if len(add_nodes) > 0: + nodes_to_remove.extend(remove_nodes) + nodes_to_add.extend(add_nodes) + new_value_infos.extend(add_value_infos) + idx += 1 + update_graph(graph, nodes_to_remove, nodes_to_add, new_value_infos) diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py new file mode 100644 index 0000000000000..e6e5ce56773e1 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py @@ -0,0 +1,178 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import itertools +from typing import Any, Dict, List, Sequence, Tuple + +import numpy as np +from onnx import GraphProto, NodeProto, TensorProto, helper, numpy_helper + + +def _get_attribute(node: NodeProto, attr_name: str, default_value: Any = None) -> Any: + """Get attribute value from node by attribute key.""" + found = [attr for attr in node.attribute if attr.name == attr_name] + if found: + return helper.get_attribute_value(found[0]) + return default_value + + +def _to_numpy_array(node: Any) -> np.ndarray: + """Convert Constant node or TensorProto to Python value.""" + tensor = node + if isinstance(node, NodeProto): + tensor = _get_attribute(node, "value") + assert isinstance(tensor, TensorProto) + return numpy_helper.to_array(tensor).tolist() + + +class GraphMatcher: + """Sub-graph matcher with given pattern. + + GraphMatcher takes an ONNX graph to initialize. It tries to match sub-graphs to a given pattern and yield + matched sub-graphs (a list of matched nodes for each sub-graph) one by one. + + Pattern is described by a list. Each entry of the list is a Tuple: + + Tuple[str, bool, List[Tuple[int, int, int]]], e.g., ("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)]) + + * First string is the Op type, e.g., "FusedMatMul". + * Second bool indicates it's producer node or consumer node for source node. + * There is a list to describe the edge infos of this node to other nodes, each edge is a tuple with 3 integers, + first integer is the index of the target node in the list, second integer is the output index of the edge, + and thrid integer is the input index of the edge. + + For each entry, GraphMatcher used the first edge to lookup target node, and try to use make sure the sug-graph also + matches rest edge infos. + + Note that when lookup target node, it will only take the first matched node as target node. For example, if a source + node has multiple "MatMul" consumers nodes comsuming same output, only the first "MatMul" node will be returned. + You need to avoid using such confusing edge info as the first edge info for node lookup. Try to use other edge to + avoid such confusion if possible. + """ + + def __init__(self, graph: GraphProto): + self._graph: GraphProto = graph + self._op_type_to_nodes: Dict[str, List[NodeProto]] = {} + self._consumer_count: Dict[str, int] = {} + for node in graph.node: + if node.op_type not in self._op_type_to_nodes: + self._op_type_to_nodes[node.op_type] = [] + self._op_type_to_nodes[node.op_type].append(node) + for input in node.input: + self._consumer_count[input] = self._consumer_count.get(input, 0) + 1 + + def _get_producer(self, arg: str, op_type: str, output_idx: int): + for node in self._op_type_to_nodes.get(op_type, []): + if (output_idx >= 0 and len(node.output) > output_idx and node.output[output_idx] == arg) or ( + output_idx == -1 and arg in node.output + ): + return node + return None + + def _get_consumer(self, arg: str, op_type: str, input_idx: int): + for node in self._op_type_to_nodes.get(op_type, []): + if (input_idx >= 0 and len(node.input) > input_idx and node.input[input_idx] == arg) or ( + input_idx == -1 and arg in node.input + ): + return node + return None + + def get_consumer_count(self, arg: str): + return self._consumer_count.get(arg, 0) + + def get_constant_value(self, arg: str): + node_or_initializer = None + if "Constant" in self._op_type_to_nodes: + for node in self._op_type_to_nodes["Constant"]: + if arg in node.output: + node_or_initializer = node + break + if node_or_initializer is None: + for initializer in self._graph.initializer: + if arg == initializer.name: + node_or_initializer = initializer + break + if node_or_initializer is None: + return None + return _to_numpy_array(node_or_initializer) + + def get_type_and_shape(self, arg: str): + value_infos = [ + value_info + for value_info in itertools.chain(self._graph.input, self._graph.value_info) + if value_info.name == arg + ] + if len(value_infos) > 0 and value_infos[0].type.tensor_type.HasField("shape"): + shape = [] + for dim in value_infos[0].type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + return value_infos[0].type.tensor_type.elem_type, shape + initializers = [initializer for initializer in self._graph.initializer if initializer.name == arg] + if len(initializers) > 0: + return initializers[0].data_type, initializers[0].dims + return None, None + + def _match_pattern(self, node: NodeProto, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]): + nodes = [node] + for i in range(1, len(pattern)): + next_op_type = pattern[i][0] + is_producer = pattern[i][1] + node_idx, output_idx, input_idx = pattern[i][2][0] + next_node = ( + self._get_producer(nodes[node_idx].input[input_idx], next_op_type, output_idx) + if is_producer + else self._get_consumer(nodes[node_idx].output[output_idx], next_op_type, input_idx) + ) + if next_node is None: + return [] + for j in range(1, len(pattern[i][2])): + node_idx, output_idx, input_idx = pattern[i][2][j] + assert output_idx >= 0 and input_idx >= 0 + if (not is_producer and nodes[node_idx].output[output_idx] != next_node.input[input_idx]) or ( + is_producer and next_node.output[output_idx] != nodes[node_idx].input[input_idx] + ): + return [] + nodes.append(next_node) + return nodes + + def match_pattern(self, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]): + for node in self._op_type_to_nodes.get(pattern[0][0], []): + result = self._match_pattern(node, pattern) + if len(result) == len(pattern): + yield result + + +def check_attribute_value(node: NodeProto, attr_name: str, expected_value: Any): + """Check if the attribute of given node has expected value.""" + value = _get_attribute(node, attr_name) + return value == expected_value + + +def make_constant_node(name: str, dtype: TensorProto.DataType, dims: Sequence[int], vals: Any): + """Create a constant node with given constant tensor (data type, shape, and data).""" + return helper.make_node( + "Constant", + inputs=[], + outputs=[name], + value=helper.make_tensor(name=name, data_type=dtype, dims=dims, vals=vals), + ) + + +def update_graph( + graph: GraphProto, + nodes_to_remove: List[NodeProto], + nodes_to_add: List[NodeProto], + new_value_infos: List[TensorProto] = [], # noqa: B006 +): + """Update an ONNX graph by removing some nodes, and adding some new nodes and value infos.""" + nodes = [node for node in graph.node if node not in nodes_to_remove] + nodes.extend(nodes_to_add) + graph.ClearField("node") + graph.node.extend(nodes) + if len(new_value_infos) > 0: + graph.value_info.extend(new_value_infos) diff --git a/orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py b/orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py deleted file mode 100644 index 70056179c140e..0000000000000 --- a/orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py +++ /dev/null @@ -1,47 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from typing import Callable - -from onnx.onnx_ml_pb2 import GraphProto - - -class GraphTransformerRegistry: - _TRANSFORMER_FUNCS = {} # noqa: RUF012 - - @classmethod - def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]): - modules = [] - if target_modules == "all": - modules.append("all") - else: - modules = target_modules.split("|") - for module in modules: - if module in cls._TRANSFORMER_FUNCS: - cls._TRANSFORMER_FUNCS[module].append((fn, devices, priority)) - else: - cls._TRANSFORMER_FUNCS[module] = [(fn, devices, priority)] - - @classmethod - def transform_all(cls, module_name: str, device: str, graph: GraphProto): - transformers_to_apply = [] - if "all" in cls._TRANSFORMER_FUNCS: - transformers_to_apply.extend(cls._TRANSFORMER_FUNCS["all"]) - if module_name in cls._TRANSFORMER_FUNCS: - transformers_to_apply.extend(cls._TRANSFORMER_FUNCS[module_name]) - transformers_to_apply = [x for x in transformers_to_apply if x[1] == "all" or device in x[1]] - transformers_to_apply.sort(key=lambda x: x[2], reverse=True) - for fn, _, _ in transformers_to_apply: - fn(graph) - - -# target_modules can be multiple module names separated by "|", or "all" means apply to all modules. -# devices can be multiple device types separated by "|" or "all" means apply to all devices. -def register_graph_transformer(target_modules: str = "all", devices: str = "all", priority: int = 0): - def graph_transformer_wrapper(fn): - GraphTransformerRegistry.register(target_modules, devices, priority, fn) - return fn - - return graph_transformer_wrapper diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 94ca87b2ac519..20b9354d85745 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -18,6 +18,7 @@ #include "orttraining/core/optimizer/concat_replacement.h" #include "orttraining/core/optimizer/batchnorm_replacement.h" #include "orttraining/core/optimizer/localized_recompute.h" +#include "orttraining/core/optimizer/transpose_replacement.h" #include "test/optimizer/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/default_providers.h" @@ -551,6 +552,46 @@ TEST_F(GraphTransformationTests, ConcatReplacement) { ASSERT_EQ(op_to_count["com.microsoft.ConcatTraining"], 1); } +TEST_F(GraphTransformationTests, TransposeReplacement) { + { + auto model_uri = MODEL_FOLDER "transpose_to_reshape_valid.onnx"; + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); + Graph& graph = p_model->MainGraph(); + + auto rule_transformer_L1 = std::make_unique("TransposeReplacement"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Transpose"], 0); + ASSERT_EQ(op_to_count["Reshape"], 1); + } + + { + auto model_uri = MODEL_FOLDER "transpose_to_reshape_invalid.onnx"; + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); + Graph& graph = p_model->MainGraph(); + + auto rule_transformer_L1 = std::make_unique("TransposeReplacement"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Transpose"], 1); + ASSERT_EQ(op_to_count["Reshape"], 0); + } +} + TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) { auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx"; std::shared_ptr p_model; diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc b/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc index 28f4ff665f797..c230a0c9a3b1d 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc @@ -17,8 +17,8 @@ InlinedHashSet TritonOp::GetBoolOutputs(size_t output_size) const { InlinedHashSet bool_outputs; for (size_t i = 0; i < output_size; ++i) { ORT_ENFORCE(i < Node().OutputDefs().size(), "Output index out of range."); - if (Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == - ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) { + if (Node().OutputDefs()[i]->Exists() && Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) { bool_outputs.insert(i); } } @@ -37,13 +37,15 @@ Status TritonOp::Compute(OpKernelContext* context) const { InlinedHashSet bool_outputs = GetBoolOutputs(output_size); auto& executor = training::framework::triton::TritonOpExecutor::Instance(); if (func_name_ != "") { - executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs); + executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs, kwargs_); } else { executor.ExecuteByOnnx(onnx_key_, onnx_string_, inputs, outputs, bool_outputs); } ORT_ENFORCE(output_size == outputs.size()); for (size_t i = 0; i < output_size; ++i) { - ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast(i), outputs[i])); + if (Node().OutputDefs()[i]->Exists()) { + ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast(i), outputs[i])); + } } return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h index 25e7b1f15ff6b..f226db76f7ed7 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h @@ -5,6 +5,8 @@ #pragma once +#include "core/common/inlined_containers.h" + #ifndef SHARED_PROVIDER #include "core/framework/op_kernel.h" #endif @@ -18,6 +20,19 @@ class TritonOp final : public OpKernel { ORT_THROW_IF_ERROR(info.GetAttr("func_name", &func_name_)); ORT_THROW_IF_ERROR(info.GetAttr("onnx_key", &onnx_key_)); ORT_THROW_IF_ERROR(info.GetAttr("onnx_string", &onnx_string_)); + for (const auto& attr : info.node().GetAttributes()) { + if (attr.first.rfind("_", 0) == 0 || attr.first == "func_name" || attr.first == "onnx_key" || + attr.first == "onnx_string") { + continue; + } + // Support int64 and float only for now, skip other types. + if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) { + kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}}); + } else if (attr.second.type() == + ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) { + kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}}); + } + } } Status Compute(OpKernelContext* context) const override; @@ -28,6 +43,7 @@ class TritonOp final : public OpKernel { std::string func_name_; int64_t onnx_key_; std::string onnx_string_; + InlinedHashMap> kwargs_; }; bool IsTritonOpExecutorInitialized(); diff --git a/pyproject.toml b/pyproject.toml index 89011a7944ab6..97515cb9fa62b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,3 +92,4 @@ unfixable = [ "tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix "onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix "onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix +"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo. diff --git a/setup.py b/setup.py index b71836e0ee6e4..9eca9845c9e8b 100644 --- a/setup.py +++ b/setup.py @@ -203,19 +203,22 @@ def run(self): "libcurand.so.10", ] rocm_dependencies = [ - "librccl.so.1", - "libnuma.so.1", "libamd_comgr.so.2", + "libamdhip64.so.5", "libdrm.so.2", - "librocblas.so.0", "libdrm_amdgpu.so.1", - "libamdhip64.so.5", - "libroctracer64.so.4", - "libMIOpen.so.1", - "libtinfo.so.6", "libelf.so.1", - "librocm_smi64.so.5", + "libhipfft.so.0", + "libhiprtc.so.5", "libhsa-runtime64.so.1", + "libMIOpen.so.1", + "libnuma.so.1", + "librccl.so.1", + "librocblas.so.3", + "librocfft.so.0", + "librocm_smi64.so.5", + "libroctracer64.so.4", + "libtinfo.so.6", ] tensorrt_dependencies = ["libnvinfer.so.8", "libnvinfer_plugin.so.8", "libnvonnxparser.so.8"] @@ -466,6 +469,7 @@ def finalize_options(self): "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils", "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator", "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops", + "onnxruntime.training.ortmodule.graph_optimizers", "onnxruntime.training.ort_triton", "onnxruntime.training.ort_triton.kernel", "onnxruntime.training.utils", diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 806e536cb4ddb..a992da8ff993e 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -66,15 +66,13 @@ def _str_to_bool(s): def _openvino_verify_device_type(device_read): - choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16", "VPUX_FP16", "VPUX_U8"] + choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16"] choices1 = [ "CPU_FP32_NO_PARTITION", "CPU_FP16_NO_PARTITION", "GPU_FP32_NO_PARTITION", "GPU_FP16_NO_PARTITION", - "VPUX_FP16_NO_PARTITION", - "VPUX_U8_NO_PARTITION", ] status_hetero = True res = False @@ -89,7 +87,7 @@ def _openvino_verify_device_type(device_read): if len(comma_separated_devices) < 2: print("At least two devices required in Hetero/Multi/Auto Mode") status_hetero = False - dev_options = ["CPU", "GPU", "VPUX"] + dev_options = ["CPU", "GPU"] for dev in comma_separated_devices: if dev not in dev_options: status_hetero = False @@ -100,7 +98,7 @@ def invalid_hetero_build(): print("specify the keyword HETERO or MULTI or AUTO followed by the devices ") print("in the order of priority you want to build\n") print("The different hardware devices that can be added in HETERO or MULTI or AUTO") - print("are ['CPU','GPU', 'VPUX'] \n") + print("are ['CPU','GPU'] \n") print("An example of how to specify the hetero build type. Ex: HETERO:GPU,CPU \n") print("An example of how to specify the MULTI build type. Ex: MULTI:GPU,CPU \n") print("An example of how to specify the AUTO build type. Ex: AUTO:GPU,CPU \n") @@ -1158,8 +1156,6 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_GPU_FP16=" + ("ON" if args.use_openvino == "GPU_FP16" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP32=" + ("ON" if args.use_openvino == "CPU_FP32" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP16=" + ("ON" if args.use_openvino == "CPU_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_FP16=" + ("ON" if args.use_openvino == "VPUX_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_U8=" + ("ON" if args.use_openvino == "VPUX_U8" else "OFF"), "-Donnxruntime_USE_OPENVINO_GPU_FP32_NP=" + ("ON" if args.use_openvino == "GPU_FP32_NO_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_GPU_FP16_NP=" @@ -1168,9 +1164,6 @@ def generate_build_tree( + ("ON" if args.use_openvino == "CPU_FP32_NO_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP16_NP=" + ("ON" if args.use_openvino == "CPU_FP16_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_FP16_NP=" - + ("ON" if args.use_openvino == "VPUX_FP16_NP_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_U8_NP=" + ("ON" if args.use_openvino == "VPUX_U8_NP_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_HETERO=" + ("ON" if args.use_openvino.startswith("HETERO") else "OFF"), "-Donnxruntime_USE_OPENVINO_DEVICE=" + (args.use_openvino), "-Donnxruntime_USE_OPENVINO_MULTI=" + ("ON" if args.use_openvino.startswith("MULTI") else "OFF"), diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 3696c41c196de..14a9bbedf09a0 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -490,7 +490,7 @@ stages: tools/ci_build/get_docker_image.py \ --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda \ --context tools/ci_build/github/linux/docker \ - --docker-build-args "--network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 --build-arg INSTALL_CUDNN=true --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 --build-arg BUILD_UID=$( id -u )" \ + --docker-build-args "--network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 --build-arg BUILD_UID=$( id -u )" \ --container-registry onnxruntimebuildcache \ --multiple_repos \ --repository onnxruntimecuda118xtrt86build diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 1d4681d064387..9e1fae343c84e 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -62,9 +62,8 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimecuda11build @@ -166,7 +165,6 @@ jobs: --network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimecuda11build diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 16d4457c45eb6..517c8d638c935 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -63,7 +63,6 @@ jobs: --network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimetensorrt86gpubuild diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 1373381e4c83e..0f6310724e9a1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.97 + version: 1.0.104 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.97 + version: 1.0.104 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml index 0d58f6cee4003..85562d7758ab2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml @@ -48,9 +48,8 @@ stages: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimecuda118xtrt86build diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml index 33c82b5e8965a..f68847afff379 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml @@ -40,9 +40,8 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }} " diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index a70e0c01e52f1..5dad3ad1f59a6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -85,9 +85,8 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }} " diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index ae2a4b4cead3d..2ba4b7bea3716 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -70,6 +70,23 @@ stages: MachinePool: onnxruntime-Win2022-GPU-T4 isTraining: true +- stage: dml + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env.bat + buildArch: x64 + additionalBuildFlags: --enable_pybind --use_dml --enable_wcos --use_winml + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: DML + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-GPU-dml-A10 + - stage: kernelDocumentation dependsOn: [] jobs: diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 7b2cada736488..d4aa9b269095f 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -1,14 +1,13 @@ -# The default ARGs are for cuda 12.2 with TRT v = 8.6.1.6-1.cuda12.0 +# The default ARGs are for cuda 11.8 with cudnn8,TensorRT is optional # Please overwirete BASEIMAGE, TRT_VERSION and other arguments with # --docker-build-args ' --build-arg BASEIMAGE=other_base_image --build-arg TRT_VERSION=other_trt_version etc...' # for other cuda version and TRT version ARG POLICY=manylinux_2_28 ARG PLATFORM=x86_64 -ARG BASEIMAGE=nvidia/cuda:12.2.0-devel-ubi8 +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ARG DEVTOOLSET_ROOTPATH=/usr ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64 ARG PREPEND_PATH=/usr/local/cuda/binet -ARG INSTALL_CUDNN=false #Build manylinux docker image begin FROM $BASEIMAGE AS runtime_base @@ -154,15 +153,6 @@ CMD ["/bin/bash"] #Build manylinux docker image end - -#Install optinal Cudnn -RUN if [ "$INSTALL_CUDNN" = true ]; then \ - CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p') && \ - dnf -y install \ - libcudnn8-devel-*cuda${CUDA_VERSION}* \ - libcudnn8-*cuda${CUDA_VERSION}* ; \ -fi - #Install TensorRT only if TRT_VERSION is not empty RUN if [ -n "$TRT_VERSION" ]; then \ echo "TRT_VERSION is $TRT_VERSION" && \ diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt index 5341ae062d332..680b12602910e 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile index 8a67692ae598b..7fa606b6c294c 100644 --- a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile @@ -66,7 +66,7 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86 rm ~/miniconda.sh && conda clean -ya # Conda base patch -RUN pip install cryptography==41.0.0 +RUN pip install cryptography==41.0.4 # Create migraphx-ci environment ENV CONDA_ENVIRONMENT_PATH /opt/miniconda/envs/migraphx-ci diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index b2893286803b0..8ef1fd4522973 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 5d48a93b09c90..5673bddfe058a 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -5,7 +5,7 @@ mypy pytest setuptools>=68.2.2 wheel>=0.35.1 -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx argparse sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 32bb99f08812e..412bc00d02778 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -68,7 +68,7 @@ RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.9 ENV PATH ${CONDA_ENVIRONMENT_PATH}/bin:${PATH} # Conda base patch -RUN pip install cryptography==41.0.0 +RUN pip install cryptography==41.0.4 # Enable rocm-ci environment SHELL ["conda", "run", "-n", "rocm-ci", "/bin/bash", "-c"] diff --git a/tools/ci_build/github/windows/setup_env_cuda.bat b/tools/ci_build/github/windows/setup_env_cuda.bat index 96569cbe0f648..2233f7611ab6a 100644 --- a/tools/ci_build/github/windows/setup_env_cuda.bat +++ b/tools/ci_build/github/windows/setup_env_cuda.bat @@ -1,15 +1,17 @@ REM Copyright (c) Microsoft Corporation. All rights reserved. REM Licensed under the MIT License. -if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ { - set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% -} else { +if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( +set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% +) else ( set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% -} +) + @REM The default version is still cuda v11.8, because set cuda v12.2 after it -if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ { +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 -} else { +) else ( set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64 -} +) + set GRADLE_OPTS=-Dorg.gradle.daemon=false diff --git a/tools/ci_build/github/windows/setup_env_gpu.bat b/tools/ci_build/github/windows/setup_env_gpu.bat index 4328c6eba1fe1..49b536e6ab81e 100644 --- a/tools/ci_build/github/windows/setup_env_gpu.bat +++ b/tools/ci_build/github/windows/setup_env_gpu.bat @@ -1,11 +1,21 @@ REM Copyright (c) Microsoft Corporation. All rights reserved. REM Licensed under the MIT License. -if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ { +if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% -} else { +) else ( set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% -} -set PATH=C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Current\Bin;%PATH% +) +set PATH=C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib;%PATH% + +@REM The default version is still cuda v11.8, because set cuda v12.2 after it +set PATH=%PATH%;C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0\lib +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( + set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 +) else ( + set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\extras\CUPTI\lib64 +) + + set GRADLE_OPTS=-Dorg.gradle.daemon=false set CUDA_MODULE_LOADING=LAZY diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index cc27cdc293646..f7b68551b9c50 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -552,6 +552,7 @@ def generate_files(line_list, args): files_list.append( "" ) + else: files_list.append( "' - ) + dll_list_path = os.path.join(openvino_path, "runtime\\bin\\intel64\\Release\\") + tbb_list_path = os.path.join(openvino_path, "runtime\\3rdparty\\tbb\\bin\\") + for dll_element in os.listdir(dll_list_path): if dll_element.endswith("dll"): files_list.append( @@ -735,26 +720,7 @@ def generate_files(line_list, args): + args.target_architecture + '\\native" />' ) - # plugins.xml - files_list.append( - "' - ) - # usb-ma2x8x.mvcmd - # OpenVINO 2022.3 doesn't have usb-ma2x8x.mvcmd - if "2022.3" not in openvino_path: - files_list.append( - "' - ) + for tbb_element in os.listdir(tbb_list_path): if tbb_element.endswith("dll"): files_list.append( diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index dcc6a92d84ef2..7a77839c4a4e7 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -93,6 +93,10 @@ def main(): # checks "onnxruntime-python-checks-ci-pipeline", "onnxruntime-binary-size-checks-ci-pipeline", + # not currently required, but running ensures we're hitting all mobile platforms + "Android CI Pipeline", + "iOS CI Pipeline", + "ONNX Runtime React Native CI Pipeline", ] # remove pipelines that have already run successfully