Adding CUDA kernel (optimized for sm80) for block-wise 4b quantized float16 GEMM #18619

Merged 13 commits on Mar 5, 2024
1 change: 1 addition & 0 deletions .lintrunner.toml
@@ -132,6 +132,7 @@ exclude_patterns = [
'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code
'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
'onnxruntime/core/mlas/**', # Contains assembly code
'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting
'winml/lib/Api.Image/shaders/**', # Contains data chunks
]
command = [
5 changes: 4 additions & 1 deletion cmake/CMakeLists.txt
@@ -715,6 +715,9 @@ if (onnxruntime_USE_CUDA)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
endif()
else()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
@@ -735,8 +738,8 @@ if (onnxruntime_USE_CUDA)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1)
endif()

endif()

if (onnxruntime_USE_VITISAI)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1)
6 changes: 3 additions & 3 deletions cmake/external/cutlass.cmake
@@ -1,8 +1,8 @@
include(FetchContent)
FetchContent_Declare(
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
)

FetchContent_GetProperties(cutlass)
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers_cuda.cmake
@@ -200,7 +200,7 @@
endif()

include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)

target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
1 change: 1 addition & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -783,6 +783,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut)
endif()
4 changes: 4 additions & 0 deletions onnxruntime/core/mickey/README.md
@@ -4,3 +4,7 @@ Playful name for a template library of high performance CUDA code that
is often shared by various AI operators. The intention is to keep this
header-files only, with no binary impact unless it is instantiated
where it is needed.

Currently CUDA code is scattered across multiple locations in the repo.
Hopefully this can be the starting point for consolidating all CUDA
code.
208 changes: 208 additions & 0 deletions onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h
@@ -0,0 +1,208 @@
/**
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*
* Module Name:
* blk_q4/f16_gemm_sm80.h
*
* Abstract:
* Entry point for Q4F16 GEMM kernel for SM80 devices.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass_ext/q4gemm/device/quantb_gemm.h"

namespace onnxruntime {
namespace cuda {

//
// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type
//
template <
typename ElementDequant_, // <- data type of dequantized elements for gemm, fp16 or bf16
typename QuantBlocking_, // <- weights block per scale, cutlass::MatrixShape<x,y>
bool SmallM, // <- true if M <= 16
bool kHasQuantOffset>
struct BlkQ4F16GemmImpl {
//
// Type definitions
//

using ElementDequant = ElementDequant_;
using QuantBlocking = QuantBlocking_;

static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kernel only supports 16b operands!");

// Data types that are fixed for this kernel
using ElementAccumulator = float;
using ElementComputeEpilogue = ElementAccumulator;
using ElementInputA = ElementDequant;
using ElementOutput = ElementDequant;

using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them

// We pack 4 weights into one 16b element, so as to leverage cutlass tile iterators
// for async shared memory loading and minimize bank conflict
using ElementWPack = ElementDequant;

using ElementQScale = ElementDequant; // <- data type of quantization scale
using ElementQOffset = uint8_t;

using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputWPack = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// Layout of quantization scale and offset, oriented to be loaded using less instructions
// in a warp tile
using LayoutInputQScale =
typename std::conditional<QuantBlocking::kRow == 1,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type; // <- layout of quantization scale

using ShapeMMAThreadBlock =
typename std::conditional<SmallM,
cutlass::gemm::GemmShape<16, 64, 64>,
cutlass::gemm::GemmShape<128, 256, 64>>::type;

static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32;
using ShapeMMAWarp =
typename std::conditional<SmallM,
cutlass::gemm::GemmShape<16, MinN, 64>,
cutlass::gemm::GemmShape<64, 64, 64>>::type;

using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- default identity threadblock swizzle

// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
// memory access. For fp16 output this is 8
// elements per 128b access. This becomes the
// vector width of math instructions in the epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function

// Number of pipelines you want to use
static constexpr int NumStages = 3;

using Gemm = cutlass::gemm::device::QuantBGemm<
ElementInputA,
LayoutInputA,
ElementWPack,
LayoutInputWPack,
ElementQScale,
typename std::conditional<kHasQuantOffset, ElementQOffset, std::monostate>::type,
LayoutInputQScale,
QuantBlocking,
ElementOutput,
LayoutOutput,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages>;

using Arguments = typename Gemm::Arguments;

// Invoke gemm kernel (the version with quantization offset)
static cutlass::Status run(
cudaStream_t stream,
const cutlass::gemm::GemmCoord& problem_size_,
cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
cutlass::TensorRef<ElementQOffset const, LayoutInputQScale> ref_Qoffset_,
cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
if constexpr (!kHasQuantOffset) {
return cutlass::Status::kErrorNotSupported;
} else {
if constexpr (ShapeMMAThreadBlock::kM == 16) {
if (problem_size_.m() > 16) {
// For M > 16, the caller should have picked the
// kernel with bigger M
return cutlass::Status::kErrorNotSupported;
}
}

// Construct Gemm arguments
Arguments args{
problem_size_,
ref_A_,
ref_B_,
ref_Qscale_,
ref_Qoffset_,
ref_C_,
ref_D_,
epilogue_};

Gemm gemm_op;

// Check if this GEMM can be run or not
cutlass::Status status = gemm_op.can_implement(args);
if (status != cutlass::Status::kSuccess) {
return status;
}

// Launch the CUTLASS GEMM kernel.
return gemm_op(args, nullptr, stream);
}
}

// Invoke gemm kernel (the version without quantization offset)
static cutlass::Status run(
cudaStream_t stream,
const cutlass::gemm::GemmCoord& problem_size_,
cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
if constexpr (kHasQuantOffset) {
return cutlass::Status::kErrorNotSupported;
} else {
if constexpr (ShapeMMAThreadBlock::kM == 16) {
if (problem_size_.m() > 16) {
// For M > 16, the caller should have picked the
// kernel with bigger M
return cutlass::Status::kErrorNotSupported;
}
}

// Construct Gemm arguments
Arguments args{
problem_size_,
ref_A_,
ref_B_,
ref_Qscale_,
ref_C_,
ref_D_,
epilogue_};

Gemm gemm_op;

// Check if this GEMM can be run or not
cutlass::Status status = gemm_op.can_implement(args);
if (status != cutlass::Status::kSuccess) {
return status;
}

// Launch the CUTLASS GEMM kernel.
return gemm_op(args, nullptr, stream);
}
}
};

} // namespace cuda
} // namespace onnxruntime
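
For orientation, here is a minimal usage sketch (not part of this PR) of how the template above might be instantiated and launched. The wrapper name, the 1x32 quantization block shape, and the leading-dimension parameters are illustrative assumptions, not APIs defined by this change.

// Hypothetical wrapper, for illustration only: fp16 activations, one scale per
// 1x32 block of weights, no zero-point offsets, and the large-M threadblock tile.
#include "blk_q4/f16_gemm_sm80.h"

cutlass::Status RunBlkQ4F16Gemm(
    cudaStream_t stream, int m, int n, int k,
    const cutlass::half_t* a,         // row-major fp16 activations, leading dim k
    const cutlass::half_t* b_packed,  // prepacked 4b weights viewed as fp16 (ElementWPack)
    int ldb,                          // leading dim of the packed weight buffer (prepack-dependent)
    const cutlass::half_t* scales,    // per-block quantization scales
    int ld_scales,                    // leading dim of the scale buffer (prepack-dependent)
    const cutlass::half_t* c,         // row-major source matrix, leading dim n
    cutlass::half_t* d) {             // row-major output, leading dim n
  using Kernel = onnxruntime::cuda::BlkQ4F16GemmImpl<
      cutlass::half_t,                // ElementDequant
      cutlass::MatrixShape<1, 32>,    // QuantBlocking: 32 weights share one scale
      /*SmallM=*/false,               // selects the 128x256x64 threadblock shape
      /*kHasQuantOffset=*/false>;     // scales only, matches the offset-free run()

  cutlass::gemm::GemmCoord problem(m, n, k);
  return Kernel::run(
      stream, problem,
      {a, cutlass::layout::RowMajor(k)},
      {b_packed, cutlass::layout::ColumnMajor(ldb)},
      {scales, cutlass::layout::ColumnMajor(ld_scales)},
      {c, cutlass::layout::RowMajor(n)},
      {d, cutlass::layout::RowMajor(n)});
}

The overload with quantization offsets takes the same arguments plus a tensor ref for the offsets between the scales and C.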
onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
@@ -3,7 +3,7 @@
* Licensed under the MIT License.
*
* Module Name:
* prepack_sm80.h
* blk_q4/f16_prepack_sm80.h
*
* Abstract:
* Prepack weights and quantization parameters (scales and offsets) for