From fb87cc9b3c9bae2a50eac15926cd0be70590b9ca Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 30 Nov 2023 01:21:04 +0000 Subject: [PATCH] add compilation flag --- cmake/CMakeLists.txt | 12 +++++++++++- cmake/external/cutlass.cmake | 2 +- cmake/onnxruntime_providers_cuda.cmake | 2 +- .../cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc | 4 ++++ .../cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 4 ++++ 5 files changed, 21 insertions(+), 3 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 5796db03fed7c..9567bb074e051 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -683,6 +683,7 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu) set(ORT_PROVIDER_FLAGS) set(ORT_PROVIDER_CMAKE_FLAGS) +set(onnxruntime_USE_CUTLASS ON) if (onnxruntime_USE_CUDA) if (onnxruntime_USE_CUDA_NHWC_OPS) add_compile_definitions(ENABLE_CUDA_NHWC_OPS) @@ -699,6 +700,10 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) + message( STATUS "Turn off cutlass since CUDA compiler version < 11.6") + set(onnxruntime_USE_CUTLASS OFF) + endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) @@ -719,8 +724,13 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) endif() - + if (onnxruntime_USE_CUTLASS) + message( STATUS "Enable CUTLASS extension") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_CUTLASS=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CUTLASS=1) + endif() endif() + if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1) diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index 983eecdd88235..efc708bd681c0 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,4 +1,4 @@ -if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) +if (onnxruntime_USE_CUTLASS) include(FetchContent) FetchContent_Declare( cutlass diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 6d7f7e3cde0cd..9d6a6701ed002 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -172,7 +172,7 @@ target_link_libraries(${target} PRIVATE cuda) endif() - if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) + if (onnxruntime_USE_CUTLASS) include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) endif() diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index ff6c38bb56d32..eab0aa19d82f2 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -11,6 +11,8 @@ * well with CUTLASS headers. */ +#if USE_CUTLASS + #include #include "core/framework/float16.h" @@ -407,3 +409,5 @@ TEST(BlkQ4_GEMM, Sm80Test) { } // namespace test } // namespace onnxruntime + +#endif // USE_CUTLASS diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index 6dcdad67e9511..7c2f99c62370a 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -11,6 +11,8 @@ * well with gtest headers. */ +#if USE_CUTLASS + #include "core/mickey/blk_q4/f16_gemm_sm80.h" #include "cutlass/util/host_tensor.h" @@ -487,3 +489,5 @@ template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k); } // namespace test } // namespace cuda } // namespace onnxruntime + +#endif // USE_CUTLASS