Skip to content

Commit

Permalink
adding cuda kernel with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chenfucn committed Nov 29, 2023
1 parent 288b80d commit ee953a8
Show file tree
Hide file tree
Showing 21 changed files with 5,949 additions and 174 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ exclude_patterns = [
'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code
'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
'onnxruntime/core/mlas/**', # Contains assembly code
'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting
'winml/lib/Api.Image/shaders/**', # Contains data chunks
]
command = [
Expand Down
6 changes: 4 additions & 2 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,10 @@
target_link_libraries(${target} PRIVATE cuda)
endif()

include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)
endif()

target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
Expand Down
1 change: 1 addition & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut)
endif()
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/mickey/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that
are often shared by various AI operators. The intention is to make this
header files only, with no binary impact unless it is instantiated
where it is needed.

Currently cuda code are scattered in multiple locations in the repo.
Hopefully this can be the starting point of consolidating all cuda
code.
208 changes: 208 additions & 0 deletions onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
/**
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*
* Module Name:
* blk_q4/f16_gemm_sm80.h
*
* Abstract:
* Entry point for Q4F16 GEMM kernel for SM80 devices.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass_ext/q4gemm/device/quantb_gemm.h"

namespace onnxruntime {
namespace cuda {

//
// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type
//
template <
typename ElementDequant_, // <- data type of dequantized elements for gemm, fp16 or bf16
typename QuantBlocking_, // <- weights block per scale, cutlass::MatrixShape<x,y>
bool SmallM, // <- true if M <= 16
bool kHasQuantOffset>
struct BlkQ4F16GemmImpl {
//
// Type definitions
//

using ElementDequant = ElementDequant_;
using QuantBlocking = QuantBlocking_;

static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kerenl only support 16b operands!");

// Data types that are fixed for this kernel
using ElementAccumulator = float;
using ElementComputeEpilogue = ElementAccumulator;
using ElementInputA = ElementDequant;
using ElementOutput = ElementDequant;

using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them

// We pack 4 weights into one 16b element, so as to leverage cutlass tile iterators
// for async shared memory loading and minimize bank conflict
using ElementWPack = ElementDequant;

using ElementQScale = ElementDequant; // <- data type of quantization scale
using ElementQOffset = uint8_t;

using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputWPack = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// Layout of quantization scale and offset, oriented to be loaded using less instructions
// in a warp tile
using LayoutInputQScale =
typename std::conditional<QuantBlocking::kRow == 1,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type; // <- layout of quantization scale

using ShapeMMAThreadBlock =
typename std::conditional<SmallM,
cutlass::gemm::GemmShape<16, 64, 64>,
cutlass::gemm::GemmShape<128, 256, 64>>::type;

static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32;
using ShapeMMAWarp =
typename std::conditional<SmallM,
cutlass::gemm::GemmShape<16, MinN, 64>,
cutlass::gemm::GemmShape<64, 64, 64>>::type;

using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??

// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
// memory access. For a byte, it's 16
// elements. This becomes the vector width of
// math instructions in the epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function

// Number of pipelines you want to use
static constexpr int NumStages = 3;

using Gemm = cutlass::gemm::device::QuantBGemm<
ElementInputA,
LayoutInputA,
ElementWPack,
LayoutInputWPack,
ElementQScale,
typename std::conditional<kHasQuantOffset, ElementQOffset, std::monostate>::type,
LayoutInputQScale,
QuantBlocking,
ElementOutput,
LayoutOutput,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages>;

using Arguments = typename Gemm::Arguments;

// Invoke gemm kernel (the version with quantization offset)
static cutlass::Status run(
cudaStream_t stream,
const cutlass::gemm::GemmCoord& problem_size_,
cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
cutlass::TensorRef<ElementQOffset const, LayoutInputQScale> ref_Qoffset_,
cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
if constexpr (!kHasQuantOffset) {
return cutlass::Status::kErrorNotSupported;
} else {
if constexpr (ShapeMMAThreadBlock::kM == 16) {
if (problem_size_.m() > 16) {
// For M > 16, the caller should have picked the
// kernel with bigger M
return cutlass::Status::kErrorNotSupported;
}
}

// Construct Gemm arguments
Arguments args{
problem_size_,
ref_A_,
ref_B_,
ref_Qscale_,
ref_Qoffset_,
ref_C_,
ref_D_,
epilogue_};

Gemm gemm_op;

// Check if this GEMM can be run or not
cutlass::Status status = gemm_op.can_implement(args);
if (status != cutlass::Status::kSuccess) {
return status;
}

// Launch the CUTLASS GEMM kernel.
return gemm_op(args, nullptr, stream);
}
}

// Invoke gemm kernel (the version without quantization offset)
static cutlass::Status run(
cudaStream_t stream,
const cutlass::gemm::GemmCoord& problem_size_,
cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
if constexpr (kHasQuantOffset) {
return cutlass::Status::kErrorNotSupported;
} else {
if constexpr (ShapeMMAThreadBlock::kM == 16) {
if (problem_size_.m() > 16) {
// For M > 16, the caller should have picked the
// kernel with bigger M
return cutlass::Status::kErrorNotSupported;
}
}

// Construct Gemm arguments
Arguments args{
problem_size_,
ref_A_,
ref_B_,
ref_Qscale_,
ref_C_,
ref_D_,
epilogue_};

Gemm gemm_op;

// Check if this GEMM can be run or not
cutlass::Status status = gemm_op.can_implement(args);
if (status != cutlass::Status::kSuccess) {
return status;
}

// Launch the CUTLASS GEMM kernel.
return gemm_op(args, nullptr, stream);
}
}
};

} // namespace cuda
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Licensed under the MIT License.
*
* Module Name:
* prepack_sm80.h
* blk_q4/f16_prepack_sm80.h
*
* Abstract:
* Prepack weights and quantization parameters (scales and offsets) for
Expand Down
Loading

0 comments on commit ee953a8

Please sign in to comment.