From b4ffef4c8cd75d5f68b6ff80bdf26100dbf73a6f Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Thu, 16 Nov 2023 06:56:14 +0000 Subject: [PATCH] Update ck --- cmake/deps.txt | 2 +- .../composable_kernel/Fix_Clang_Build.patch | 17 ++- .../rocm/diffusion/group_norm_ck.cuh | 2 + .../diffusion/group_norm_ck_impl/impl.cuh | 124 +++++++++--------- .../diffusion/group_norm_ck_impl/impl_fp16.cu | 13 +- .../diffusion/group_norm_ck_impl/impl_fp32.cu | 9 +- 6 files changed, 90 insertions(+), 77 deletions(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index 49142372ab86e..e065cacdfc423 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -54,4 +54,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/a4f72a314a85732ed67d5aa8d1088d207a7e0e61.zip;f57357ab6d300e207a632d034ebc8aa036a090d9 +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch index 02b30af9eef52..15844dd917744 100644 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ b/cmake/patches/composable_kernel/Fix_Clang_Build.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index b09da41a8..fca2bdf69 100644 +index 04674124c..12e8b8b00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ endif() @@ -48,7 +48,18 @@ index b09da41a8..fca2bdf69 100644 ## tidy include(EnableCompilerWarnings) -@@ -489,11 +466,3 @@ rocm_install(FILES +@@ -376,7 +353,9 @@ if(BUILD_DEV) + add_compile_options(-Werror -Weverything) + endif() + #add flags to reduce the size of binaries +-add_compile_options(-Oz -flto=thin) ++# -flto requires ORT to use a linker that support LTO and -flto flag shoud be passed to linker together. ++# add_compile_options(-Oz -flto=thin) ++add_compile_options(-Oz) + message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) +@@ -482,11 +461,3 @@ rocm_install(FILES set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") @@ -61,7 +72,7 @@ index b09da41a8..fca2bdf69 100644 - HEADER_ONLY -) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -index a0478c9f0..1e7782cd4 100644 +index 9cb5d0e9a..141a46f3d 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -44,8 +44,14 @@ function(add_instance_library INSTANCE_NAME) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index 0146e81c6cf8c..2e0b29cc30ad2 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -69,6 +69,8 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { gamma_beta_strides, // gammaStrides gamma_beta_strides, // betaStrides in_out_strides, // yStrides + {0, 0}, // saveMeanStrides + {0, 0}, // saveInvStdStrides reduce_dims, // reduceDims params->epsilon, params->src, diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 88443478cf521..0012f455b29f5 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -6,8 +6,8 @@ #ifdef USE_COMPOSABLE_KERNEL #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" #include "ck/utility/data_type.hpp" namespace onnxruntime { @@ -21,58 +21,60 @@ using F32 = float; using Swish = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; -using ck::tensor_operation::device::DeviceNormalization; // the interface -using ck::tensor_operation::device::DeviceNormalizationImpl; // the implementation +using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface +using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation + +// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp template using device_normalization_f32_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; template -using device_normalization_f16_instances = std::tuple< +using device_normalization_f16_instances = // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + std::tuple < + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; @@ -85,38 +87,38 @@ template -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { return {}; } template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances< - F16, F32, F32, F32, F16, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Swish, 5, 3>(); template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances< - F16, F32, F32, F32, F16, Pass, 5, 3>(); + F16, F32, F32, F16, F32, Pass, 5, 3>(); template <> -std::vector>> GetDeviceGroupNormInstances< F32, F32, F32, F32, F32, Swish, 5, 3>(); template <> -std::vector>> GetDeviceGroupNormInstances< F32, F32, F32, F32, F32, Pass, 5, 3>(); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu index d1dd78e3452da..6718f29268031 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu @@ -4,7 +4,6 @@ #ifdef USE_COMPOSABLE_KERNEL #include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" namespace onnxruntime { namespace contrib { @@ -12,9 +11,9 @@ namespace rocm { namespace internal { template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f16_instances{}); @@ -23,9 +22,9 @@ GetDeviceGroupNormInstances() { } template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f16_instances{}); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 97baed34a341d..9b0ccab17b4c1 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -4,7 +4,6 @@ #ifdef USE_COMPOSABLE_KERNEL #include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" namespace onnxruntime { namespace contrib { @@ -12,9 +11,9 @@ namespace rocm { namespace internal { template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { - std::vector>> instances; + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f32_instances{}); @@ -23,9 +22,9 @@ GetDeviceGroupNormInstances() { } template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { - std::vector>> instances; + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f32_instances{});