require cuda 11.4 for cutlass

chenfucn committed Nov 30, 2023
1 parent fb87cc9 commit 8aeb46c
Showing 7 changed files with 26 additions and 46 deletions.
9 changes: 1 addition & 8 deletions cmake/CMakeLists.txt
@@ -683,7 +683,6 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu)
set(ORT_PROVIDER_FLAGS)
set(ORT_PROVIDER_CMAKE_FLAGS)

set(onnxruntime_USE_CUTLASS ON)
if (onnxruntime_USE_CUDA)
if (onnxruntime_USE_CUDA_NHWC_OPS)
add_compile_definitions(ENABLE_CUDA_NHWC_OPS)
@@ -701,8 +700,7 @@ if (onnxruntime_USE_CUDA)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
message( STATUS "Turn off cutlass since CUDA compiler version < 11.6")
set(onnxruntime_USE_CUTLASS OFF)
message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
endif()
else()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
@@ -724,11 +722,6 @@ if (onnxruntime_USE_CUDA)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1)
endif()
if (onnxruntime_USE_CUTLASS)
message( STATUS "Enable CUTLASS extension")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CUTLASS=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CUTLASS=1)
endif()
endif()

if (onnxruntime_USE_VITISAI)
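For reference, here is a minimal standalone sketch of the hard version gate this hunk introduces: when CUDA support is requested, configuration now fails outright instead of silently turning CUTLASS off when the CUDA compiler is older than 11.4. The project and option names (demo_cuda_gate, DEMO_USE_CUDA) are hypothetical, not part of the ONNX Runtime build.

```cmake
# Minimal sketch of a hard CUDA-version gate; names here are illustrative.
cmake_minimum_required(VERSION 3.18)
project(demo_cuda_gate LANGUAGES CXX)

option(DEMO_USE_CUDA "Build the CUDA provider" ON)

if (DEMO_USE_CUDA)
  enable_language(CUDA)
  # CMAKE_CUDA_COMPILER_VERSION is populated once the CUDA language is enabled.
  if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
    # Fail configuration instead of toggling features off one by one.
    message(FATAL_ERROR
            "CUDA ${CMAKE_CUDA_COMPILER_VERSION} found, but CUTLASS-based "
            "kernels require CUDA 11.4 or newer.")
  endif()
endif()
```

With the build failing fast on old toolkits, the separate onnxruntime_USE_CUTLASS switch and its USE_CUTLASS compile definition become unnecessary, which is why the remaining hunks delete them.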
14 changes: 6 additions & 8 deletions cmake/external/cutlass.cmake
@@ -1,13 +1,11 @@
if (onnxruntime_USE_CUTLASS)
include(FetchContent)
FetchContent_Declare(
include(FetchContent)
FetchContent_Declare(
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
)
)

FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
6 changes: 2 additions & 4 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -172,10 +172,8 @@
target_link_libraries(${target} PRIVATE cuda)
endif()

if (onnxruntime_USE_CUTLASS)
include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)
endif()
include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)

target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
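Taken together, the cutlass.cmake and onnxruntime_providers_cuda.cmake hunks make the CUTLASS fetch-and-include path unconditional. Below is a hedged sketch of the same pattern; the target name (demo_cutlass_headers) and the v3.1.0 pin are illustrative stand-ins, since the real build reads DEP_URL_cutlass/DEP_SHA1_cutlass from its dependency manifest.

```cmake
# Sketch of an unconditional CUTLASS fetch plus header-only consumption.
cmake_minimum_required(VERSION 3.18)
project(demo_cutlass LANGUAGES CXX)

include(FetchContent)

FetchContent_Declare(
  cutlass
  # Example pin; the real build also verifies the archive's SHA1 hash.
  URL https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.tar.gz
)

# Populate (download + extract) without add_subdirectory: CUTLASS is consumed
# header-only, so none of its own CMake targets are built.
FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
  FetchContent_Populate(cutlass)
endif()

# An INTERFACE target carrying the same include directories the CUDA provider
# adds; any consumer links against it to see the CUTLASS headers.
add_library(demo_cutlass_headers INTERFACE)
target_include_directories(demo_cutlass_headers INTERFACE
  ${cutlass_SOURCE_DIR}/include
  ${cutlass_SOURCE_DIR}/examples
  ${cutlass_SOURCE_DIR}/tools/util/include)
```

Because the fetch no longer sits behind onnxruntime_USE_CUTLASS, the test sources can also drop their #if USE_CUTLASS guards, as the last two hunks show.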
@@ -184,7 +184,7 @@ class QuantBGemm {
using QuantBlocking = QuantBlocking_;
static constexpr bool kHasQOffset = !(std::is_same<ElementQOffset, std::monostate>::value);

// TODO enable uint4_t or smaller for QOffset
// TODO(chenfucn): consider moving to uint4_t or smaller for QOffset
static_assert(!kHasQOffset || std::is_same<ElementQOffset_, uint8_t>::value, "QOffset must be uint8_t");

/// Define the kernel
@@ -378,8 +378,7 @@ class QuantBGemm {
return Status::kErrorInternal;
}
}
}
else {
} else {

if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
@@ -651,7 +651,7 @@ class QuantBMmaMultistage :
smem_iterator_B_.add_tile_offset({1, 0});
smem_iterator_QScale_.add_tile_offset({1, 0});

if constexpr (kHasQOffset){
if constexpr (kHasQOffset) {
iterator_QOffset.add_tile_offset({1, 0});
smem_iterator_QOffset_.add_tile_offset({1, 0});
}
@@ -664,7 +664,7 @@ class QuantBMmaMultistage :
smem_iterator_A_.add_tile_offset({0, -Base::kStages});
smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_iterator_QScale_.add_tile_offset({-Base::kStages, 0});
if constexpr (kHasQOffset){
if constexpr (kHasQOffset) {
smem_iterator_QOffset_.add_tile_offset({-Base::kStages, 0});
}
smem_write_stage_idx_ = 0;
@@ -703,7 +703,7 @@ class QuantBMmaMultistage :
static_assert(IteratorQOffset::kAccessesPerVector == 1,
"Quant offset should 1 access per vector!");

if constexpr(kHasQOffset){
if constexpr(kHasQOffset) {
// Async Copy for quantization offset
typename IteratorQOffset::AccessType *dst_ptr =
reinterpret_cast<typename IteratorQOffset::AccessType *>(
@@ -872,7 +872,7 @@ class QuantBMmaMultistage :
cutlass::arch::cp_async<kSrcBytes, kCacheOpQScale>(
dst_ptr, gmem_ptr, iterator_QScale.valid());

if constexpr (kHasQOffset){
if constexpr (kHasQOffset) {
iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_);

// Async Copy for quantization offset
@@ -907,8 +907,8 @@ class QuantBMmaMultistage :
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();

if constexpr(debug_layout){
if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1){
if constexpr(debug_layout) {
if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) {
if (threadIdx.x == 0){
printf("stage: %d\n", smem_write_stage_idx_);
}
@@ -957,7 +957,7 @@ class QuantBMmaMultistage :
iterator_B,
(warp_mma_k + 1) % Base::kWarpGemmIterations);

if constexpr(debug_layout){
if constexpr(debug_layout) {
if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){
printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations);
}
@@ -974,7 +974,7 @@ class QuantBMmaMultistage :
pipe_state.warp_loaded_frag_QScale_,
pipe_state.warp_loaded_frag_QOffset_);

if constexpr(debug_layout){
if constexpr(debug_layout) {
LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
}

@@ -1049,7 +1049,7 @@ class QuantBMmaMultistage :
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
// In the case of small M, memory latency dominates. We try to move uses far
// from their definitions to hide latency.
if constexpr(debug_layout){
if constexpr(debug_layout) {
if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){
printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations);
}
@@ -1066,7 +1066,7 @@ class QuantBMmaMultistage :
pipe_state.warp_loaded_frag_QScale_,
pipe_state.warp_loaded_frag_QOffset_);

if constexpr(debug_layout){
if constexpr(debug_layout) {
LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
}

@@ -1159,7 +1159,7 @@ class QuantBMmaMultistage :
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_);
if constexpr(kHasQOffset){
if constexpr(kHasQOffset) {
iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_);
}

@@ -1180,9 +1180,9 @@ class QuantBMmaMultistage :

copy_tiles_and_advance(iterator_A, iterator_B, 0);

if constexpr(Shape::kM > 32){
if constexpr(Shape::kM > 32) {
// the case of bigger m
if constexpr(debug_layout){
if constexpr(debug_layout) {
if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){
printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0);
}
@@ -1199,7 +1199,7 @@ class QuantBMmaMultistage :
pipe_state.warp_loaded_frag_QScale_,
pipe_state.warp_loaded_frag_QOffset_);

if constexpr(debug_layout){
if constexpr(debug_layout) {
LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
}
} else {
@@ -1215,7 +1215,7 @@ class QuantBMmaMultistage :
// Mainloop
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);) {
if constexpr(Shape::kM > 32){
if constexpr(Shape::kM > 32) {
mac_loop_iter(
pipe_state,
accum,
@@ -11,8 +11,6 @@
* well with CUTLASS headers.
*/

#if USE_CUTLASS

#include <random>

#include "core/framework/float16.h"
@@ -409,5 +407,3 @@ TEST(BlkQ4_GEMM, Sm80Test) {

} // namespace test
} // namespace onnxruntime

#endif // USE_CUTLASS
@@ -11,8 +11,6 @@
* well with gtest headers.
*/

#if USE_CUTLASS

#include "core/mickey/blk_q4/f16_gemm_sm80.h"

#include "cutlass/util/host_tensor.h"
@@ -489,5 +487,3 @@ template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k);
} // namespace test
} // namespace cuda
} // namespace onnxruntime

#endif // USE_CUTLASS
