Skip to content

Commit

Permalink
Remove USE_CUTLASS flag (#19271)
Browse files Browse the repository at this point in the history
### Description
Since Cutlass can be built with CUDA 11.4 (The minimum CUDA version for
onnxruntime CUDA build), there is no need to have a flag to disable
cutlass.

Changes:
(1) Reverted #18761
(2) remove the condition to build cutlass.
(3) Fix a few build errors or warnings during testing CUDA 11.4 build. 

Note that SM 89 and 90 (including fp8) requires CUDA 11.8 or later.
Flash attention and cutlass fused multihead attention will not be built
for CUDA < 11.6. It is recommended to use CUDA 11.8 or above to build if
you want to support latest GPUs.

It is better to include it in 1.17.0 (otherwise, the release branch
might encounter build failure with CUDA 11.4).

Tests:
(1) Build with flash attention and efficient attention off: **passed**
(2) Build with CUDA 11.4: **passed**

Example build command used in Ubuntu 20.04:
```
export CUDA_HOME=/usr/local/cuda-11.4
export CUDNN_HOME=/usr/lib/x86_64-linux-gnu/
export CUDACXX=/usr/local/cuda-11.4/bin/nvcc

sh build.sh --config Release  --build_shared_lib --parallel  --use_cuda --cuda_version 11.4 \
            --cuda_home $CUDA_HOME --cudnn_home $CUDNN_HOME --build_wheel --skip_tests \
            --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \
            --disable_types float8
```

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
tianleiwu authored Jan 26, 2024
1 parent 656ca66 commit 8b45172
Show file tree
Hide file tree
Showing 26 changed files with 25 additions and 131 deletions.
23 changes: 7 additions & 16 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)

cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF)
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

Expand Down Expand Up @@ -707,20 +706,16 @@ if (onnxruntime_USE_CUDA)
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")

if (onnxruntime_DISABLE_CONTRIB_OPS)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off cutlass since CUDA compiler version < 11.6")
set(onnxruntime_USE_CUTLASS OFF)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
else()
set(onnxruntime_USE_CUTLASS OFF)
endif()

if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS)
if (onnxruntime_DISABLE_CONTRIB_OPS)
message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled")
else()
message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled")
endif()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
Expand Down Expand Up @@ -906,10 +901,6 @@ function(onnxruntime_set_compile_flags target_name)
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
endif()

if (onnxruntime_USE_CUTLASS)
target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
endif()

if(USE_NEURAL_SPEED)
target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED)
endif()
Expand Down
20 changes: 9 additions & 11 deletions cmake/external/cutlass.cmake
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
if (onnxruntime_USE_CUTLASS)
include(FetchContent)
FetchContent_Declare(
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
)
include(FetchContent)
FetchContent_Declare(
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
)

FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
4 changes: 0 additions & 4 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
Expand Down Expand Up @@ -204,5 +202,3 @@ Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif
4 changes: 0 additions & 4 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#pragma once

#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
Expand Down Expand Up @@ -36,5 +34,3 @@ class ShardedMoE final : public NcclKernel, public MoEBase {
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif
8 changes: 0 additions & 8 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
#ifdef USE_CUTLASS
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE);
#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
Expand Down Expand Up @@ -169,10 +167,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);

#ifdef USE_CUTLASS
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE);
#endif

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul);
Expand Down Expand Up @@ -272,10 +268,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
#ifdef USE_CUTLASS
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
Expand Down Expand Up @@ -377,10 +371,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,

#ifdef USE_CUTLASS
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE)>,
#endif

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul)>,
Expand Down
5 changes: 0 additions & 5 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef USE_CUTLASS

#pragma once

#include <cuda_runtime_api.h>
Expand Down Expand Up @@ -52,5 +49,3 @@ inline int compute_occupancy_for_kernel() {
}

} // namespace ort_fastertransformer

#endif
11 changes: 4 additions & 7 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USE_CUTLASS

#include "cutlass_heuristic.h"

Expand Down Expand Up @@ -66,9 +65,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
}

// Check that the workspace has sufficient space for this split-k factor
const int ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
const int ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);
const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
const size_t ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
const size_t ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);
const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;

if (required_ws_bytes > workspace_bytes) {
return false;
Expand Down Expand Up @@ -128,7 +127,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
int current_m_tile = 0;

const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
for (int ii = 0; ii < candidate_configs.size(); ++ii) {
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
CutlassGemmConfig candidate_config = candidate_configs[ii];
TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
int occupancy = occupancies[ii];
Expand Down Expand Up @@ -186,5 +185,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
}

} // namespace ort_fastertransformer

#endif
2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USE_CUTLASS

#pragma once

Expand All @@ -38,4 +37,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
const int multi_processor_count, const int is_weight_only);

} // namespace ort_fastertransformer
#endif
4 changes: 0 additions & 4 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
*
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/array.h"
Expand Down Expand Up @@ -133,5 +131,3 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
};

} // namespace ort_fastertransformer

#endif
4 changes: 0 additions & 4 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#pragma once

namespace ort_fastertransformer {
Expand Down Expand Up @@ -58,5 +56,3 @@ struct CutlassGemmConfig {
};

} // namespace ort_fastertransformer

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
*
**************************************************************************************************/

#ifdef USE_CUTLASS

/*! \file
\brief Scheduler for grouped GEMM
*/
Expand Down Expand Up @@ -79,5 +77,3 @@ struct GemmMoeProblemVisitor
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/layout/matrix.h"
Expand Down Expand Up @@ -152,6 +150,4 @@ struct MixedGemmArchTraits<

} // namespace kernel
} // namespace gemm
} // namespace cutlass

#endif
} // namespace cutlass
4 changes: 0 additions & 4 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
*
**************************************************************************************************/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/complex.h"
Expand Down Expand Up @@ -463,5 +461,3 @@ struct MoeFCGemm {
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

#endif
4 changes: 0 additions & 4 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#pragma once

#include <cuda_runtime_api.h>
Expand Down Expand Up @@ -64,5 +62,3 @@ class MoeGemmRunner {
};

} // namespace ort_fastertransformer

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#include "moe_gemm_kernels_template.h"

namespace ort_fastertransformer {
template class MoeGemmRunner<half, half>;
} // namespace ort_fastertransformer

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#include "moe_gemm_kernels_template.h"

namespace ort_fastertransformer {
template class MoeGemmRunner<float, float>;
} // namespace ort_fastertransformer

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__
#pragma GCC diagnostic push
Expand Down Expand Up @@ -428,5 +426,3 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A, const WeightType* B, con
}

} // namespace ort_fastertransformer

#endif
4 changes: 0 additions & 4 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
Expand Down Expand Up @@ -900,5 +898,3 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half
cudaStream_t);

} // namespace ort_fastertransformer

#endif
6 changes: 1 addition & 5 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#pragma once

#include "moe_gemm_kernels.h"
Expand Down Expand Up @@ -174,6 +172,4 @@ class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_s
}
};

} // namespace ort_fastertransformer

#endif
} // namespace ort_fastertransformer
4 changes: 0 additions & 4 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
\brief Base scheduler for grouped problems, using MoE
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
Expand Down Expand Up @@ -290,5 +288,3 @@ struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode:
} // namespace kernel
} // namespace gemm
} // namespace cutlass

#endif
Loading

0 comments on commit 8b45172

Please sign in to comment.