Adding cuda kernel (optimized for sm80) for block-wise 4b quantized float 16 GEMM. (#18619)

### Description
Adding a CUDA kernel for block-wise 4b quantized float16 GEMM, specially
optimized for Nvidia Ampere GPUs.


### Motivation and Context
Trying to improve quantized LLM inference performance on Nvidia Ampere
GPUs.

### Note:
This is implemented by extending CUTLASS, so it has a hard dependency on
CUTLASS. However, in the current build system, loading of the CUTLASS
dependency is guarded by:

(onnxruntime_USE_FLASH_ATTENTION OR
onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)

If both of these options are turned off, compilation will fail.

Why is the CUTLASS dependency guarded at all? It's a header-only library
that does not introduce any binary unless instantiated. What's the
downside of removing all the guards and including CUTLASS unconditionally?
chenfucn authored Mar 5, 2024
1 parent bdf678d commit 06e684c
Showing 25 changed files with 6,257 additions and 513 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
@@ -132,6 +132,7 @@ exclude_patterns = [
'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code
'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
'onnxruntime/core/mlas/**', # Contains assembly code
'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting
'winml/lib/Api.Image/shaders/**', # Contains data chunks
]
command = [
5 changes: 4 additions & 1 deletion cmake/CMakeLists.txt
@@ -727,6 +727,9 @@ if (onnxruntime_USE_CUDA)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
endif()
else()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
@@ -747,8 +750,8 @@ if (onnxruntime_USE_CUDA)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1)
endif()

endif()

if (onnxruntime_USE_VITISAI)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1)
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers_cuda.cmake
@@ -201,7 +201,7 @@
endif()

include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)

target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
1 change: 1 addition & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -774,6 +774,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut)
endif()
4 changes: 4 additions & 0 deletions onnxruntime/core/mickey/README.md
@@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that
is often shared by various AI operators. The intention is to make this
header-files only, with no binary impact unless it is instantiated
where it is needed.

Currently CUDA code is scattered across multiple locations in the repo.
Hopefully this can be the starting point for consolidating all CUDA
code.
208 changes: 208 additions & 0 deletions onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h
@@ -0,0 +1,208 @@
/**
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*
* Module Name:
* blk_q4/f16_gemm_sm80.h
*
* Abstract:
* Entry point for Q4F16 GEMM kernel for SM80 devices.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass_ext/q4gemm/device/quantb_gemm.h"

namespace onnxruntime {
namespace cuda {

//
// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type
//
template <
typename ElementDequant_, // <- data type of dequantized elements for gemm, fp16 or bf16
typename QuantBlocking_, // <- weights block per scale, cutlass::MatrixShape<x,y>
bool SmallM, // <- true if M <= 16
bool kHasQuantOffset>
struct BlkQ4F16GemmImpl {
//
// Type definitions
//

using ElementDequant = ElementDequant_;
using QuantBlocking = QuantBlocking_;

static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kernel only supports 16b operands!");

// Data types that are fixed for this kernel
using ElementAccumulator = float;
using ElementComputeEpilogue = ElementAccumulator;
using ElementInputA = ElementDequant;
using ElementOutput = ElementDequant;

using ElementW = uint8_t; // <- Weights are int4; one uint8 holds two of them

// We pack 4 weights into one 16b element to leverage the cutlass tile iterators
// for async shared memory loading and to minimize bank conflicts
using ElementWPack = ElementDequant;
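// (Four 4b weights = 16 bits, exactly one ElementDequant-sized word.)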

using ElementQScale = ElementDequant; // <- data type of quantization scale
using ElementQOffset = uint8_t;

using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputWPack = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// Layout of quantization scales and offsets, oriented so they can be loaded
// with fewer instructions in a warp tile
using LayoutInputQScale =
typename std::conditional<QuantBlocking::kRow == 1,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type; // <- layout of quantization scale
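// For example, QuantBlocking = MatrixShape<1, X> (kRow == 1) stores scales
// column-major, while MatrixShape<X, 1> stores them row-major, keeping the
// scales consumed together in a warp tile contiguous in memory.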

using ShapeMMAThreadBlock =
typename std::conditional<SmallM,
cutlass::gemm::GemmShape<16, 64, 64>,
cutlass::gemm::GemmShape<128, 256, 64>>::type;
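// The skinny 16x64x64 threadblock tile serves GEMV-like shapes (M <= 16);
// the 128x256x64 tile favors data reuse when M is large.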

static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32;
using ShapeMMAWarp =
typename std::conditional<SmallM,
cutlass::gemm::GemmShape<16, MinN, 64>,
cutlass::gemm::GemmShape<64, 64, 64>>::type;

using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;
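// 16x8x16 is the fp16 tensor core mma instruction shape (m16n8k16) on SM80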

// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- default identity swizzle

// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
// 128b memory access. For fp16 output,
// that's 8 elements. This becomes the vector
// width of math instructions in the epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function

// Number of mainloop pipeline stages (global -> shared memory prefetch depth)
static constexpr int NumStages = 3;

using Gemm = cutlass::gemm::device::QuantBGemm<
ElementInputA,
LayoutInputA,
ElementWPack,
LayoutInputWPack,
ElementQScale,
typename std::conditional<kHasQuantOffset, ElementQOffset, std::monostate>::type,
LayoutInputQScale,
QuantBlocking,
ElementOutput,
LayoutOutput,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages>;

using Arguments = typename Gemm::Arguments;

// Invoke gemm kernel (the version with quantization offset)
static cutlass::Status run(
cudaStream_t stream,
const cutlass::gemm::GemmCoord& problem_size_,
cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
cutlass::TensorRef<ElementQOffset const, LayoutInputQScale> ref_Qoffset_,
cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
if constexpr (!kHasQuantOffset) {
return cutlass::Status::kErrorNotSupported;
} else {
if constexpr (ShapeMMAThreadBlock::kM == 16) {
if (problem_size_.m() > 16) {
// For M > 16, the caller should have picked the
// kernel with bigger M
return cutlass::Status::kErrorNotSupported;
}
}

// Construct Gemm arguments
Arguments args{
problem_size_,
ref_A_,
ref_B_,
ref_Qscale_,
ref_Qoffset_,
ref_C_,
ref_D_,
epilogue_};

Gemm gemm_op;

// Check if this GEMM can be run or not
cutlass::Status status = gemm_op.can_implement(args);
if (status != cutlass::Status::kSuccess) {
return status;
}

// Launch the CUTLASS GEMM kernel.
return gemm_op(args, nullptr, stream);
}
}

// Invoke gemm kernel (the version without quantization offset)
static cutlass::Status run(
cudaStream_t stream,
const cutlass::gemm::GemmCoord& problem_size_,
cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
if constexpr (kHasQuantOffset) {
return cutlass::Status::kErrorNotSupported;
} else {
if constexpr (ShapeMMAThreadBlock::kM == 16) {
if (problem_size_.m() > 16) {
// For M > 16, the caller should have picked the
// kernel with bigger M
return cutlass::Status::kErrorNotSupported;
}
}

// Construct Gemm arguments
Arguments args{
problem_size_,
ref_A_,
ref_B_,
ref_Qscale_,
ref_C_,
ref_D_,
epilogue_};

Gemm gemm_op;

// Check if this GEMM can be run or not
cutlass::Status status = gemm_op.can_implement(args);
if (status != cutlass::Status::kSuccess) {
return status;
}

// Launch the CUTLASS GEMM kernel.
return gemm_op(args, nullptr, stream);
}
}
};

} // namespace cuda
} // namespace onnxruntime
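
For orientation, here is a minimal sketch of how this entry point might be
instantiated and invoked; the wrapper name, tensor refs, and the 64x1
blocking choice below are illustrative assumptions, not part of the commit:

```cpp
// Hypothetical usage sketch: instantiate the kernel for fp16 with 64x1
// quantization blocking (64 weights along K share one scale), the large-M
// tile shape, and no zero-point offsets, then launch it on a stream.
#include "blk_q4/f16_gemm_sm80.h"

using Kernel = onnxruntime::cuda::BlkQ4F16GemmImpl<
    cutlass::half_t,              // ElementDequant: fp16
    cutlass::MatrixShape<64, 1>,  // QuantBlocking: 64x1 weights per scale
    /*SmallM=*/false,             // use the 128x256x64 threadblock tile
    /*kHasQuantOffset=*/false>;   // scales only, no zero points

// All tensor refs are assumed to point to device memory prepared by the
// prepack routines shipped alongside this kernel.
cutlass::Status RunQ4F16Gemm(
    cudaStream_t stream, const cutlass::gemm::GemmCoord& problem_size,
    cutlass::TensorRef<cutlass::half_t const, cutlass::layout::RowMajor> A,
    cutlass::TensorRef<cutlass::half_t const, cutlass::layout::ColumnMajor> B_packed,
    cutlass::TensorRef<cutlass::half_t const, Kernel::LayoutInputQScale> scales,
    cutlass::TensorRef<cutlass::half_t const, cutlass::layout::RowMajor> C,
    cutlass::TensorRef<cutlass::half_t, cutlass::layout::RowMajor> D) {
  // Offset-free overload: returns kErrorNotSupported if the M check or
  // cutlass can_implement() rejects the problem shape.
  return Kernel::run(stream, problem_size, A, B_packed, scales, C, D);
}
```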
onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
@@ -3,7 +3,7 @@
* Licensed under the MIT License.
*
* Module Name:
* prepack_sm80.h
* blk_q4/f16_prepack_sm80.h
*
* Abstract:
* Prepack weights and quantization parameters (scales and offsets) for