diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 23ded3bfc1e68..34355fb0fd936 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -96,6 +96,7 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) +cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) @@ -693,16 +694,20 @@ if (onnxruntime_USE_CUDA) enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") - if (onnxruntime_DISABLE_CONTRIB_OPS) - set(onnxruntime_USE_FLASH_ATTENTION OFF) - set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) - endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) - message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") - set(onnxruntime_USE_FLASH_ATTENTION OFF) - set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) + message( STATUS "Turn off cutlass since CUDA compiler version < 11.6") + set(onnxruntime_USE_CUTLASS OFF) endif() else() + set(onnxruntime_USE_CUTLASS OFF) +endif() + +if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS) + if (onnxruntime_DISABLE_CONTRIB_OPS) + message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled") + else() + message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled") + endif() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() @@ -887,6 +892,11 @@ function(onnxruntime_set_compile_flags target_name) if (onnxruntime_ENABLE_ATEN) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() + + if (onnxruntime_USE_CUTLASS) + target_compile_definitions(${target_name} PRIVATE USE_CUTLASS) + endif() + set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) if (onnxruntime_USE_CUDA) # Suppress a "conversion_function_not_usable" warning in gsl/span diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index 983eecdd88235..efc708bd681c0 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,4 +1,4 @@ -if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) +if (onnxruntime_USE_CUTLASS) include(FetchContent) FetchContent_Declare( cutlass diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 40a667ffd5d83..9b989dac9a94b 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef USE_CUTLASS + #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" @@ -202,3 +204,5 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, } // namespace cuda } // namespace contrib } // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h index 5ea4ae59c4020..cbd483fddab78 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef USE_CUTLASS + #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -34,3 +36,5 @@ class ShardedMoE final : public NcclKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 7875ac75b8188..be7e9f6a8225e 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -70,8 +70,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); +#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); +#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); @@ -165,8 +167,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll); +#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE); +#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); @@ -266,8 +270,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -367,8 +373,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, +#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h index 86136ea244e23..9b97690fe70fd 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h @@ -13,6 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#ifdef USE_CUTLASS + #pragma once #include @@ -49,3 +52,5 @@ inline int compute_occupancy_for_kernel() { } } // namespace ort_fastertransformer + +#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc index 5d4c6793ec995..f0abd46572a90 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifdef USE_CUTLASS #include "cutlass_heuristic.h" @@ -185,3 +186,5 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector @@ -62,3 +64,5 @@ class MoeGemmRunner { }; } // namespace ort_fastertransformer + +#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu index 1d9a249db4237..1d0dfe7c5a647 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -14,8 +14,12 @@ * limitations under the License. */ +#ifdef USE_CUTLASS + #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer + +#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu index 7b250e6ca9060..7a5d97902ee8f 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -14,8 +14,12 @@ * limitations under the License. */ +#ifdef USE_CUTLASS + #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer + +#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 66950c9b65970..3fd0fc47055a5 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -14,6 +14,8 @@ * limitations under the License. */ +#ifdef USE_CUTLASS + // Ignore CUTLASS warnings about type punning #ifdef __GNUC__ #pragma GCC diagnostic push @@ -426,3 +428,5 @@ void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, con } } // namespace ort_fastertransformer + +#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index f4f2b49032d23..9232e8d012933 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -16,6 +16,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef USE_CUTLASS + #include #include #include @@ -898,3 +900,5 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half cudaStream_t); } // namespace ort_fastertransformer + +#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index 5cc2a3f79f003..f09471de1cc2e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -16,6 +16,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef USE_CUTLASS + #pragma once #include "moe_gemm_kernels.h" @@ -172,4 +174,6 @@ class CutlassMoeFCRunner> { } // namespace layout } // namespace cutlass + +#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 3f26a274109ad..0da06192e266b 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef USE_CUTLASS + #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "moe.h" @@ -117,3 +119,5 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { } // namespace cuda } // namespace contrib } // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h index c4d8c4dc64c57..710b914f0633d 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef USE_CUTLASS + #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -24,3 +26,5 @@ class MoE final : public CudaKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index f55a7cde2e208..dc8b9d57f79f6 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef USE_CUTLASS + #pragma once #include "core/common/common.h" @@ -170,3 +172,5 @@ class MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime + +#endif diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index ebb0261deefa5..844cc877f2568 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef USE_CUTLASS + #include "gtest/gtest.h" #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" @@ -421,3 +423,5 @@ TEST(MoETest, MoETest_Relu) { } // namespace test } // namespace onnxruntime + +#endif