From d9fae0cfc52559ccb8c94e4bc57aed16188a13db Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Tue, 5 Dec 2023 17:31:26 +0000 Subject: [PATCH] Add fp16xq4 matmul sm80 cuda kernel to ORT operator And to use graph transformer as prepack --- cmake/onnxruntime_optimizer.cmake | 4 + cmake/onnxruntime_providers_cuda.cmake | 1 + cmake/onnxruntime_unittests.cmake | 7 + docs/ContribOperators.md | 8 + .../core/optimizer/graph_transformer_utils.h | 5 +- .../cuda/quantization/matmul_nbits.cc | 58 ++- .../cuda/quantization/matmul_nbits.cu | 339 +++++++++++++++ .../cuda/quantization/matmul_nbits.cuh | 11 + .../cuda/quantization/matmul_nbits.h | 21 +- .../core/graph/contrib_ops/contrib_defs.cc | 11 + onnxruntime/core/graph/graph_utils.cc | 6 - onnxruntime/core/graph/graph_utils.h | 22 +- .../core/mickey/blk_q4/f16_gemm_sm80.h | 26 +- .../core/mickey/blk_q4/f16_prepack_sm80.h | 88 +++- .../threadblock/quantb_mma_multistage.h | 2 +- .../core/mickey/gemm/kernel/quant_b4_gemm.h | 26 +- .../mickey/gemm/warp/quantb_meta_loader.h | 23 +- onnxruntime/core/optimizer/gpu_ops_prepack.cc | 326 +++++++++++++++ onnxruntime/core/optimizer/gpu_ops_prepack.h | 28 ++ .../core/optimizer/graph_transformer_utils.cc | 21 +- .../providers/cuda/cuda_provider_factory.cc | 19 +- .../providers/cuda/cuda_provider_factory.h | 3 +- onnxruntime/core/session/inference_session.cc | 16 +- onnxruntime/core/util/matrix_layout.h | 6 +- .../test/cuda_host/blkq4_fp16_quant_sm80.h | 153 ++++++- .../test/optimizer/gpu_op_prepack_test.cc | 393 ++++++++++++++++++ .../test/optimizer/graph_transform_test.cc | 27 +- .../optimizer/graph_transform_test_builder.cc | 1 + .../optimizer/graph_transform_utils_test.cc | 16 +- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 11 +- .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 28 +- .../cuda_execution_provider_test.cc | 1 - .../cuda/test_cases/cuda_test_provider.cc | 6 + .../optimizer/graph_transformer_utils_test.cc | 16 +- 34 files changed, 1611 insertions(+), 118 deletions(-) create mode 100644 onnxruntime/core/optimizer/gpu_ops_prepack.cc create mode 100644 onnxruntime/core/optimizer/gpu_ops_prepack.h create mode 100644 onnxruntime/test/optimizer/gpu_op_prepack_test.cc diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 3bae1b8a48e0f..a2156f2c0e9f5 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -111,6 +111,10 @@ onnxruntime_add_static_library(onnxruntime_optimizer ${onnxruntime_optimizer_src onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}) + +# using optimizer as cuda prepacking, so extra headers are needed +target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) + if (onnxruntime_ENABLE_TRAINING) target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT}) onnxruntime_add_include_to_target(onnxruntime_optimizer nlohmann_json::nlohmann_json) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 3b48a40bf1166..10e88308cea53 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -217,6 +217,7 @@ include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 711a9f77f9094..1ade7f4109457 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -272,6 +272,13 @@ if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) "${TEST_SRC_DIR}/optimizer/*.h" ) + if (MSVC AND ((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC"))) + set_source_files_properties("${TEST_SRC_DIR}/optimizer/graph_transform_test.cc" PROPERTIES COMPILE_FLAGS "/bigobj") + list(REMOVE_ITEM onnxruntime_test_optimizer_src + "${TEST_SRC_DIR}/optimizer/gpu_op_prepack_test.cc" + ) + endif() + set(onnxruntime_test_framework_src_patterns "${TEST_SRC_DIR}/framework/*.cc" "${TEST_SRC_DIR}/framework/*.h" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 45306c852a906..9f2fb44b9756d 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2915,6 +2915,14 @@ This version of the operator has been available since version 1 of the 'com.micr
number of bits used for weight quantization (default 4)
block_size : int (required)
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
+
column_wise_blocking : int
+
whether to quantize weight columnwise (value 1, default), or rowwise (value 0)
+
prepacked : int
+
+Indicates whether the weight matrix is prepacked (value 1), or not (value 0, default). +This property should NEVER be set by user. It is set by ONNX Runtime internally during +model loading time. +
#### Inputs (3 - 6) diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index e609745b5e03f..56853e3229030 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -22,6 +22,7 @@ namespace onnxruntime { class IExecutionProvider; +class ExecutionProviders; namespace optimizer_utils { @@ -48,7 +49,7 @@ std::unique_ptr GenerateRuleBasedGraphTransformer( InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, - const IExecutionProvider& execution_provider /*required by constant folding*/, + const ExecutionProviders& execution_providers /* cpu ep required by constant folding*/, const InlinedHashSet& rules_and_transformers_to_disable = {}); #endif // !defined(ORT_MINIMAL_BUILD) @@ -77,7 +78,7 @@ InlinedVector> GenerateTransformersForMinimalB TransformerLevel level, const SessionOptions& session_options, const SatApplyContextVariant& apply_context, - const IExecutionProvider& cpu_execution_provider, + const ExecutionProviders& execution_providers, const InlinedHashSet& rules_and_transformers_to_disable = {}); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 1cec6f6a12f1c..2c3db2a3d2b47 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -14,24 +14,44 @@ namespace onnxruntime { namespace contrib { namespace cuda { -using namespace onnxruntime::cuda; + +#ifndef USE_ROCM +template <> +Status MatMulNBits::PrepackedGemm( + cudaStream_t stream, + int M, + const Tensor* a, + const Tensor* b, + const Tensor* scales, + const Tensor* zero_points, + Tensor* Y) const { + uint8_t const* zero_points_ptr = nullptr; + size_t zero_points_size = 0; + if (zero_points != nullptr) { + zero_points_ptr = zero_points->Data(); + zero_points_size = zero_points->Shape().Size(); + } + + return blkq4_fp16_gemm_sm80_dispatch( + int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream, + a->Data(), a->Shape().Size(), + b->Data(), b->Shape().Size(), + scales->Data(), scales->Shape().Size(), + zero_points_ptr, zero_points_size, + Y->MutableData(), Y->Shape().Size()); +} +#endif // !USE_ROCM template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { + using CudaT = typename ToCudaType::MappedType; + const Tensor* a = ctx->Input(0); const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); const Tensor* reorder_idx = ctx->Input(4); - const auto* a_data = a->Data(); - const uint8_t* blob_data = b->Data(); - const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); - const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); - - typedef typename ToCudaType::MappedType CudaT; - constexpr bool transa = false; constexpr bool transb = true; MatMulComputeHelper helper; @@ -43,6 +63,26 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); + if (prepack_ > 0) { +#ifndef USE_ROCM + ORT_RETURN_IF(reorder_idx != nullptr, + "Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!"); + ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType(), + "Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!"); + return PrepackedGemm( + static_cast(ctx->GetComputeStream()->GetHandle()), + static_cast(helper.M()), a, b, scales, zero_points, Y); +#else + ORT_RETURN_IF(true, "Prepacked gemm is not supported for MatMulNBits op."); +#endif // !USE_ROCM + } + + const auto* a_data = a->Data(); + const uint8_t* blob_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); + bool is_4bit_done = (reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType()) && TryMatMul4Bits( diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index af9e87eaf225d..2309ef4fb7a27 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -10,6 +10,9 @@ #include "core/providers/cuda/cuda_common.h" #include "matmul_nbits.cuh" +#include "blk_q4/f16_gemm_sm80.h" +#include "gemm/device/quant_b4_gemm.h" + using namespace onnxruntime::cuda; using namespace cub; @@ -349,6 +352,342 @@ template bool TryMatMul4Bits( int shared_mem_per_block, cudaStream_t stream); +/** + * @brief Helper function to run the GEMM kernel for 4bits quantized gemm on SM80. + * Only support fp16 for now. +*/ +template< + typename ElementT, + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream, + gsl::span a, + gsl::span weights, + gsl::span scales, + gsl::span offsets, + gsl::span output) { + static_assert(std::is_same::value + || std::is_same::value + || std::is_same::value, + "Only support fp16 for now"); + using ElementDequant = cutlass::half_t; + using QuantBlocking = + typename std::conditional, + cutlass::MatrixShape<1, block_size>>::type; + + using GemmRunner = onnxruntime::cuda::BlkQ4F16GemmImpl; + + using ElementAccumulator = typename GemmRunner::ElementAccumulator; + using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue; + using ElementOutput = typename GemmRunner::ElementOutput; + using ElementW = typename GemmRunner::ElementW; + using ElementWPack = typename GemmRunner::ElementWPack; + using ElementQScale = typename GemmRunner::ElementQScale; + using ElementQOffset = typename GemmRunner::ElementQOffset; + + using LayoutInputA = typename GemmRunner::LayoutInputA; + using LayoutOutput = typename GemmRunner::LayoutOutput; + using LayoutInputWPack = typename GemmRunner::LayoutInputWPack; + using LayoutInputQScale = typename GemmRunner::LayoutInputQScale; + + const cutlass::gemm::GemmCoord problem_size = {m, n, k}; + + ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct: ", a.size_bytes(), " vs m: ", m, "k: ", k , " size: ", m * k * sizeof(ElementDequant)); + cutlass::TensorRef ref_a( + reinterpret_cast(a.data()), + LayoutInputA::packed({m, k})); + + ORT_RETURN_IF_NOT(weights.size_bytes() == k/2 * n/2 * sizeof(ElementWPack), "weights size is not correct"); + cutlass::TensorRef ref_W( + reinterpret_cast(weights.data()), + LayoutInputWPack::packed({k/2, n/2})); + + ORT_RETURN_IF_NOT(scales.size_bytes() == (k/QuantBlocking::kRow) * (n/QuantBlocking::kColumn) * sizeof(ElementQScale), + "scales size is not correct"); + cutlass::TensorRef ref_scales( + reinterpret_cast(scales.data()), + LayoutInputQScale::packed({k/QuantBlocking::kRow, n/QuantBlocking::kColumn})); + + ORT_RETURN_IF_NOT(output.size_bytes() == m * n * sizeof(ElementOutput), "output size is not correct"); + cutlass::TensorRef ref_output( + reinterpret_cast(output.data()), + LayoutOutput::packed({m, n})); + + // run GEMM + cutlass::Status status; + if constexpr (has_offsets) { + ORT_RETURN_IF_NOT(offsets.size_bytes() == (k/QuantBlocking::kRow) * (n/QuantBlocking::kColumn) * sizeof(ElementQOffset), + "offsets size is not correct"); + cutlass::TensorRef ref_offsets( + reinterpret_cast(offsets.data()), + LayoutInputQScale::packed({k/QuantBlocking::kRow, n/QuantBlocking::kColumn})); + status = GemmRunner::run( + stream, problem_size, ref_a, ref_W, ref_scales, ref_offsets, + ref_output, ref_output); + } else { + status = GemmRunner::run( + stream, problem_size, ref_a, ref_W, ref_scales, + ref_output, ref_output); + } + ORT_RETURN_IF_NOT(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); + return Status::OK(); +} + +/** + * @brief The GEMM kernel for 4bits quantized gemm on SM80 -- small size gemm version. + */ +template < + typename ElementT, + int block_size, + bool column_wise_blocking, + bool has_offsets> +Status blkq4_small_gemm_sm80( + int m, int n, int k, cudaStream_t stream, + const ElementT* ptr_a, + size_t lda, + const uint8_t* ptr_packed_b, + const ElementT* ptr_scales, + const uint8_t* ptr_offsets, + ElementT* ptr_c, + size_t ldc) { + using QuantBlocking = + typename std::conditional, + cutlass::MatrixShape<1, block_size>>::type; + using LayoutQmeta = + typename std::conditional::type; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 64>; + + const cutlass::gemm::GemmCoord problem_size = {m, n, k}; + const auto meta_shape = cutlass::make_Coord(problem_size.k() / QuantBlocking::kRow, + problem_size.n() / QuantBlocking::kColumn); + if ((problem_size.k() % QuantBlocking::kRow != 0) || + (problem_size.n() % QuantBlocking::kColumn) != 0){ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Partial quantization block in B not supported!"); + } + + // run GEMM + size_t meta_stride = static_cast(LayoutQmeta::packed(meta_shape).stride(0)); + const void* ptr_zp = has_offsets ? ptr_offsets : nullptr; + size_t zp_byte_stride = has_offsets ? meta_stride * sizeof(uint8_t) : size_t(0); + + cutlass::Status status; + + if (k <= 384) { + status = mickey::gemm::device::QuantB4Gemm::run( + stream, problem_size, + ptr_c, ldc * sizeof(ElementT), + ptr_a, lda * sizeof(ElementT), + ptr_packed_b, k * sizeof(uint8_t), + ptr_scales, meta_stride * sizeof(ElementT), + ptr_zp, zp_byte_stride); + } else if (k <= 768) { + status = mickey::gemm::device::QuantB4Gemm::run( + stream, problem_size, + ptr_c, ldc * sizeof(ElementT), + ptr_a, lda * sizeof(ElementT), + ptr_packed_b, k * sizeof(uint8_t), + ptr_scales, meta_stride * sizeof(ElementT), + ptr_zp, zp_byte_stride); + } else if (k < 1536) { + status = mickey::gemm::device::QuantB4Gemm::run( + stream, problem_size, + ptr_c, ldc * sizeof(ElementT), + ptr_a, lda * sizeof(ElementT), + ptr_packed_b, k * sizeof(uint8_t), + ptr_scales, meta_stride * sizeof(ElementT), + ptr_zp, zp_byte_stride); + } else { + status = mickey::gemm::device::QuantB4Gemm::run( + stream, problem_size, + ptr_c, ldc * sizeof(ElementT), + ptr_a, lda * sizeof(ElementT), + ptr_packed_b, k * sizeof(uint8_t), + ptr_scales, meta_stride * sizeof(ElementT), + ptr_zp, zp_byte_stride); + } + + ORT_RETURN_IF_NOT(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); + return Status::OK(); +} + + +template +Status +blkq4_fp16_gemm_sm80_dispatch( + int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream, + ElementT const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + ElementT const* scales_ptr, size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + ElementT* output_ptr, size_t output_size) { + auto a = gsl::make_span(a_ptr, a_size); + auto weights = gsl::make_span(weights_ptr, weights_size); + auto scales = gsl::make_span(scales_ptr, scales_size); + auto offsets = gsl::make_span(offsets_ptr, offsets_size); + auto output = gsl::make_span(output_ptr, output_size); + + switch (block_size) + { + case 16: + if (column_wise_blocking) { + if (m <= 64 && n < 16384) { + if (offsets.empty()) + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + else + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + } + if (m > 32) { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } + } else { + if (m <= 64 && n < 16384) { + if (offsets.empty()) + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + else + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + } + if (m > 32) { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } + } + break; + + case 32: + if (column_wise_blocking) { + if (m <= 64 && n < 16384) { + if (offsets.empty()) + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + else + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + } + if (m > 32) { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } + } else { + if (m <= 64 && n < 16384) { + if (offsets.empty()) + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + else + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + } + if (m > 32) { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } + } + break; + + case 64: + if (column_wise_blocking) { + if (m <= 64 && n < 16384) { + if (offsets.empty()) + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + else + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + } + if (m > 32) { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } + } else { + if (m <= 64 && n < 16384) { + if (offsets.empty()) + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + else + return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n); + } + if (m > 32) { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); + } + } + break; + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported block size: ", block_size); +} + +template +Status blkq4_fp16_gemm_sm80_dispatch( + int block_size, + bool column_wise_blocking, + int m, int n, int k, cudaStream_t stream, + half const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + half const* scales_ptr, size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + half* output_ptr, size_t output_size); + +template +Status blkq4_fp16_gemm_sm80_dispatch( + int block_size, + bool column_wise_blocking, + int m, int n, int k, cudaStream_t stream, + cutlass::half_t const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + cutlass::half_t const* scales_ptr, size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + cutlass::half_t* output_ptr, size_t output_size); + +template +Status blkq4_fp16_gemm_sm80_dispatch( + int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream, + onnxruntime::MLFloat16 const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + onnxruntime::MLFloat16 const* scales_ptr, size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + onnxruntime::MLFloat16* output_ptr, size_t output_size); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh index 9ccbe4c4d97a8..b3a325bf806b4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -22,6 +22,17 @@ bool TryMatMul4Bits( int shared_mem_per_block, cudaStream_t stream); +template +Status blkq4_fp16_gemm_sm80_dispatch( + int block_size, + bool column_wise_blocking, + int m, int n, int k, cudaStream_t stream, + ElementT const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + ElementT const* scales_ptr, size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + ElementT* output_ptr, size_t output_size); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index f5c2c6c4e4fdf..a8008e9cdcfa7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -14,7 +14,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -using namespace onnxruntime::cuda; template class MatMulNBits final : public CudaKernel { @@ -24,8 +23,27 @@ class MatMulNBits final : public CudaKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op," + " additional bits support is planned."); + int64_t column_wise_quant_blk = 1; + info.GetAttrOrDefault("column_wise_blocking", &column_wise_quant_blk, int64_t(1)); + column_wise_quant_blk_ = column_wise_quant_blk != 0; + info.GetAttrOrDefault("prepacked", &prepack_, int64_t(0)); } +#ifndef USE_ROCM + Status PrepackedGemm([[maybe_unused]] cudaStream_t stream, + [[maybe_unused]] int M, + [[maybe_unused]] const Tensor* a, + [[maybe_unused]] const Tensor* b, + [[maybe_unused]] const Tensor* scales, + [[maybe_unused]] const Tensor* zero_points, + [[maybe_unused]] Tensor* Y) const { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Prepacked gemm is not supported for MatMulNBits op."); + } +#endif // !USE_ROCM + Status ComputeInternal(OpKernelContext* context) const override; private: @@ -34,6 +52,7 @@ class MatMulNBits final : public CudaKernel { int64_t block_size_; int64_t nbits_; bool column_wise_quant_blk_{true}; + int64_t prepack_{0}; }; } // namespace cuda diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index dea8775c89a30..91cd1fe37d31b 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3431,6 +3431,15 @@ Input zero_points is stored as uint8_t or same as type(A). It has the same packi If zero_points has same type as A, it's not packed and has the same shape as Scales. )DOC"; + // This is really bad as we expose a property to the user that should never be set by user. + // We have to use this as we perform cuda prepacking during graph optimization phase and + // this is the only way to pass the information to the runtime. + static const char* PrepackProperty_doc = R"DOC( +Indicates whether the weight matrix is prepacked (value 1), or not (value 0, default). +This property should NEVER be set by user. It is set by ONNX Runtime internally during +model loading time. +)DOC"; + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) .SetDomain(kMSDomain) .SinceVersion(1) @@ -3446,6 +3455,8 @@ Input zero_points is stored as uint8_t or same as type(A). It has the same packi "computation. 4 means input A can be quantized with the same block_size to int8 internally from " "type T1.", AttributeProto::INT, static_cast(0)) + .Attr("column_wise_blocking", "whether to quantize weight columnwise (value 1, default), or rowwise (value 0)", AttributeProto::INT, static_cast(1)) + .Attr("prepacked", PrepackProperty_doc, AttributeProto::INT, static_cast(0)) .Input(0, "A", "The input tensor, not quantized", "T1") .Input(1, "B", "1 or 2 dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 13620f4d8b3bb..6b2138f39e1b3 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -240,12 +240,6 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node, std::string_view op_typ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) -const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) { - const auto& attrs = node.GetAttributes(); - const auto iter = attrs.find(attr_name); - return iter == attrs.end() ? nullptr : &iter->second; -} - NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { // sanity check as AddInitializedTensor silently ignores attempts to add a duplicate initializer const ONNX_NAMESPACE::TensorProto* existing = nullptr; diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 319e055200cca..e27007a6eecfc 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -26,7 +26,27 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node, #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Returns the attribute of a Node with a given name. */ -const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name); +static inline const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) { + const auto& attrs = node.GetAttributes(); + const auto iter = attrs.find(attr_name); + return iter == attrs.end() ? nullptr : &iter->second; +} + +template +inline Status TryGetNodeAttribute(const Node& node, const std::string& attr_name, T& value); + +template <> +inline Status TryGetNodeAttribute(const Node& node, const std::string& attr_name, int64_t& value) { + const auto* attr = GetNodeAttribute(node, attr_name); + if (!attr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node attribute '", attr_name, "' is not set."); + } + if (!attr->has_i()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node attribute '", attr_name, "' is not an integer."); + } + value = attr->i(); + return Status::OK(); +} /** Add a new initializer to 'graph'. Checks that new_initializer does not already exist in 'graph' before adding it. diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h index 52bff7e40dbe3..0b3cf3c0e75fa 100644 --- a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h @@ -11,9 +11,19 @@ #pragma once +// Ignore CUTLASS warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + #include "cutlass/cutlass.h" #include "cutlass_ext/q4gemm/device/quantb_gemm.h" +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + namespace onnxruntime { namespace cuda { @@ -127,14 +137,6 @@ struct BlkQ4F16GemmImpl { if constexpr (!kHasQuantOffset) { return cutlass::Status::kErrorNotSupported; } else { - if constexpr (ShapeMMAThreadBlock::kM == 16) { - if (problem_size_.m() > 16) { - // For M > 16, the caller should have picked the - // kernel with bigger M - return cutlass::Status::kErrorNotSupported; - } - } - // Construct Gemm arguments Arguments args{ problem_size_, @@ -172,14 +174,6 @@ struct BlkQ4F16GemmImpl { if constexpr (kHasQuantOffset) { return cutlass::Status::kErrorNotSupported; } else { - if constexpr (ShapeMMAThreadBlock::kM == 16) { - if (problem_size_.m() > 16) { - // For M > 16, the caller should have picked the - // kernel with bigger M - return cutlass::Status::kErrorNotSupported; - } - } - // Construct Gemm arguments Arguments args{ problem_size_, diff --git a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index c81b4967d2719..1492d9846a08e 100644 --- a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -19,6 +19,7 @@ #pragma once #include "core/common/common.h" +#include "core/framework/float16.h" #include "core/util/matrix_layout.h" namespace onnxruntime { @@ -79,6 +80,23 @@ struct BlockwiseQuantization { return make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); } + static inline bool weight_dimension_supported(int rows, int columns) { + // prepacking works on a 16x16 block, so the dimensions must be multiples of 16 + if (((rows % 16) != 0) || ((columns % 16) != 0)) { + return false; + } + + // verify the dimensions are multiples of the block size + if (((rows % QuantBlocking::kRow) != 0) || ((columns % QuantBlocking::kColumn) != 0)) { + return false; + } + + // All the above restrictions can be relaxed by adding more logic to the prepack + // and gemm implementation to support edge cases. But for now, these restrictions + // does not affect LLM weight dimensions, so we keep it simple. + return true; + } + /** * @brief Prepack weight matrix to facilitate matrix loading, depending on MMA * instruction layout. @@ -113,10 +131,10 @@ struct BlockwiseQuantization { gsl::span weights, // <- int4 weights, column major gsl::span weights_prepacked // <- int4 prepacked weights tensor, same size buffer ) { - ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 && - (rows % QuantBlocking::kRow) == 0 && - (columns % QuantBlocking::kColumn) == 0, - "Does not support odd number of rows or columns!"); +#ifndef NDEBUG + ORT_ENFORCE(weight_dimension_supported(rows, columns), + "This function must be guarded by weight_dimension_supported()!"); +#endif ORT_ENFORCE(weights.size() == size_t(rows * columns / 2), "Weight tensor shape mismatch!"); ORT_ENFORCE(weights_prepacked.size() == weights.size(), @@ -174,6 +192,10 @@ struct BlockwiseQuantization { gsl::span scales, // <- quant scales, column major layout gsl::span scales_prepacked // <- quant scales prepacked, same size buffer ) { +#ifndef NDEBUG + ORT_ENFORCE(weight_dimension_supported(static_cast(rows), static_cast(columns)), + "This function must be guarded by weight_dimension_supported()!"); +#endif auto meta_shape = get_quant_meta_shape(static_cast(rows), static_cast(columns)); ORT_ENFORCE(scales.size() == size_t(meta_shape.product()), "Quantization scale tensor shape mismatch!"); @@ -244,8 +266,11 @@ struct BlockwiseQuantization { gsl::span offsets, // <- quant offsets, int4, column major layout gsl::span offsets_prepacked // <- quant offsets prepacked, double size buffer ) { +#ifndef NDEBUG + ORT_ENFORCE(weight_dimension_supported(static_cast(rows), static_cast(columns)), + "This function must be guarded by weight_dimension_supported()!"); +#endif auto meta_shape = get_quant_meta_shape(static_cast(rows), static_cast(columns)); - ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0, "Does not support odd number of rows or columns!"); ORT_ENFORCE(offsets_prepacked.size() == size_t(meta_shape.product()), @@ -321,5 +346,58 @@ struct BlockwiseQuantization { } }; +static inline bool IsSm80WithWholeBlocks( + int weight_rows, [[maybe_unused]] int weight_cols, + int major, [[maybe_unused]] int minor) { + if (major < 8) { + return false; + } + + // Kernel implementation detail: + // K must be aligned with thread block tile (64) due to the way + // predicate iterator works, it loads the partial tile + // in the first iteration and then the full tile in the + // remaining iterations. This will cause the blockwise + // quantization parameters to go out of step with the + // weights. + // To fix this, we need to write our own predicate iterator + // that loads the full tile in the first iterations and + // then the partial tile in the last iteration. + return (weight_rows % 64 == 0); +} + +template +inline bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) { + using Base = BlockwiseQuantization; + if (!Base::weight_dimension_supported(weight_rows, weight_cols)) { + return false; + } + return IsSm80WithWholeBlocks(weight_rows, weight_cols, major, minor); +} + +static inline bool BlkQuantGemmSm80Supported(int block_size, bool col_blocking, int weight_rows, int weight_cols, int major, int minor) { + switch (block_size) { + case 16: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + case 32: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + case 64: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + } + return false; +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h index 6e281241a3427..26e878f6120d9 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -132,7 +132,7 @@ struct DummyType{ } CUTLASS_HOST_DEVICE - std::monostate& operator[](int /*idx */) { + std::monostate& operator[]([[maybe_unused]] int idx) { return dummy_; } }; diff --git a/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h b/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h index a0695dbbfd347..97f5773617cf8 100644 --- a/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h +++ b/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h @@ -386,9 +386,9 @@ struct QuantB4Gemm { if constexpr (kSplitK > 1){ // TODO! Use thread block shape - if (params.gemm_k_size_ < WarpShape::kK * kStages * 2) { + if (params.gemm_k_size_ < WarpShape::kK * kStages) { // spliting too small, may not get enough iterations to rampup pipeline - std::cerr << "QuantB4Gemm validation fail: split k too big, each k segment: " << params.gemm_k_size_ << " is smaller than " << (WarpShape::kK * kStages * 2) << std::endl; + std::cerr << "QuantB4Gemm validation fail: split k too big, each k segment: " << params.gemm_k_size_ << " is smaller than " << (WarpShape::kK * kStages) << std::endl; return cutlass::Status::kErrorNotSupported; } } @@ -534,21 +534,6 @@ struct QuantB4Gemm { // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) cutlass::arch::cp_async_wait(); - //__syncthreads(); is this necessary since the loader is warp based? - // if constexpr(kDebugPrintA) { - // if (lane_idx == 0) { - // printf("Prologue, warp: %d, WarpPtr: %p\n", - // warp_idx, a_shared_ptr); - // printf("\n********Dumping the shared memory of Warp %d*******\n\n", warp_idx); - - // for (int i = 0; i < MainLoopSharedBuffer::kASize; i += 8) { - // for (int j = 0; j < 8; ++j) { - // printf("%f, ", float(a_shared_ptr[i + j])); - // } - // printf("\n"); - // } - // } - // } // // Prefix of the Mainloop, pre-loading the double buffer in registers @@ -678,11 +663,10 @@ struct QuantB4Gemm { a_tile_loader.load_to_smem_split(lane_idx, a_smem_write_ptr, next_iter * kAGloadsPerIter + i); } if constexpr (kDebugPrintA) { - const int lane_id = threadIdx.x % 32; - if (lane_id == 0) { - printf("==== A tiles =======\n"); + if (lane_idx == 0) { + printf("===== Warp %d A tiles =======\n", warp_idx); } - const char* const format = (lane_id == 31) ? "%f, %f\n\n" : ((lane_id % 4) == 3) ? "%f, %f\n" : "%f, %f, "; + const char* const format = (lane_idx == 31) ? "%f, %f\n\n" : ((lane_idx % 4) == 3) ? "%f, %f\n" : "%f, %f, "; const ElementT* a_ptr = fragment_a[iter2 % 2].data(); for (int m2_tile = 0; m2_tile < (WarpShape::kM / InstructionShape::kM); ++m2_tile, a_ptr += 8) { printf(format, float(a_ptr[0]), float(a_ptr[1])); diff --git a/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h b/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h index 4a784a1a49109..2422f1b25a366 100644 --- a/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h +++ b/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h @@ -47,10 +47,9 @@ void weightsMinuEight2Half(uint32_t const &weights, // // 1.125 instruction per weight, 9 instructions in total. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 500)) uint32_t* b32s = reinterpret_cast(dest.data()); const uint32_t high_8s = weights >> 8; - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 500)) asm volatile( " lop3.b32 %0, %4, 0x000f000f, %6, 0xea;\n" " lop3.b32 %1, %4, 0x00f000f0, %7, 0xea;\n" @@ -67,6 +66,8 @@ void weightsMinuEight2Half(uint32_t const &weights, "r"(0x64086408)); #else assert(false); + (void)(weights); + (void)(dest); #endif } @@ -331,6 +332,7 @@ struct QuantBScaleLoader, WarpShape_, Eleme // only one scale/offset, so the block size cannot be smaller than 16. static_assert(QuantBlocking::kRow % 16 == 0); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) const int meta_k = k_iter / (QuantBlocking::kRow / 16); half const* scales = reinterpret_cast(frag_scales.data() + meta_k * kMetaFragSize); [[maybe_unused]] half const* offsets = nullptr; @@ -391,6 +393,14 @@ struct QuantBScaleLoader, WarpShape_, Eleme } } } +#else + assert(false); + (void)(k_iter); + (void)(frag_pack_b); + (void)(frag_scales); + (void)(frag_offsets); + (void)(frag_b); +#endif // __CUDA_ARCH__ } }; @@ -611,6 +621,7 @@ struct QuantBScaleLoader, WarpShape_, Eleme constexpr int kPackedBKStride = PackedBSize / kPackedBNTiles; static_assert(kPackedBKStride * kPackedBNTiles == PackedBSize); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) // Row-wise quantization, every row has its own scale/offset CUTLASS_PRAGMA_UNROLL for (int nn = 0; nn < (WarpShape::kN / 16); ++nn) { @@ -680,6 +691,14 @@ struct QuantBScaleLoader, WarpShape_, Eleme } } } +#else + assert(false); + (void)(k_iter); + (void)(frag_pack_b); + (void)(frag_scales); + (void)(frag_offsets); + (void)(frag_b); +#endif // __CUDA_ARCH__ } }; diff --git a/onnxruntime/core/optimizer/gpu_ops_prepack.cc b/onnxruntime/core/optimizer/gpu_ops_prepack.cc new file mode 100644 index 0000000000000..aed5817c21868 --- /dev/null +++ b/onnxruntime/core/optimizer/gpu_ops_prepack.cc @@ -0,0 +1,326 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// Module Abstract: +// This module defines the logic for prepacking weights +// (aka Initializers in onnxruntime) of GPU operators. +// Unlike CPU operators, overriding the PrePack() method +// of class OpKernel results in GPU memory fragmentation. +// So we try to rewrite the weight tensors during graph +// optimization phase to avoid this problem +// +// Unfortunately, there are still some seriouse problems +// with this approach: +// 1. Rewriting of the initializer tensors is restricted +// by operator shape inferencing rules. For example, +// there are 3 initializers for MatMulNBits, +// we can't combine them into a single initializer. +// And we have to make sure the operator's shape inference +// logic does NOT verify the initializer's shape. +// 2. These rewriting logic is tightly coupled to each GPU +// operators. It really should be defined together with +// these operators, instead of defining them in a complete +// different module. +// 3. The logic of prepacking depends on underlying GPU +// hardware. Currently this part is hard-coded for SM80. + +#if defined(USE_CUDA) && !defined(USE_ROCM) + +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/gpu_ops_prepack.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph_utils.h" + +#include "blk_q4/f16_prepack_sm80.h" + +#include "core/providers/cuda/cuda_provider_factory.h" +#include "core/providers/cuda/cuda_execution_provider_info.h" + +namespace onnxruntime { + +extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); + +/** + * @brief Read initialized tensor from protobuf, and store it in ort_value. + * Keep in mind that ort_value is the owner of the tensor memory after calling this function. + */ +inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_value) { + const ONNX_NAMESPACE::TensorProto* tensor_proto; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto), + "Missing initializer for ", arg->Name()); + + return utils::TensorProtoToOrtValue( + Env::Default(), graph.ModelPath(), *tensor_proto, + std::make_shared(), ort_value); +} + +template +inline gsl::span make_span(std::string& str) { + return gsl::make_span(reinterpret_cast(str.data()), str.size() / sizeof(T)); +} + +// +// Prepacking logic specific to MatMulNBits on sm80 +// + +static inline bool IsNodeMatMulNbitsFp16(const Node& node) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain)) { + return false; + } + const auto* acts = node.InputDefs()[0]; + if (acts == nullptr || acts->Type() == nullptr || acts->Type()->find("float16") == std::string::npos) { + return false; + } + return true; +} + +template +void Sm80BlkQ4PrepackT( + int rows, int columns, + gsl::span weights, + gsl::span scales, + gsl::span zp, + std::string& packed_w, + std::string& packed_scales, + std::string& packed_zp) { + using Base = onnxruntime::cuda::BlockwiseQuantization< + MLFloat16, + block_size, + 4, + column_quant_blk>; + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + + packed_w.resize(SafeInt(q_weight_shape.product() * sizeof(uint8_t))); + Base::prepack_weights( + rows, columns, weights, + make_span(packed_w)); + + packed_scales.resize(SafeInt(meta_shape.product() * sizeof(MLFloat16))); + Base::prepack_quant_scales( + rows, columns, scales, + make_span(packed_scales)); + + if (!zp.empty()) { + packed_zp.resize(SafeInt(meta_shape.product() * sizeof(uint8_t))); + Base::prepack_quant_offsets( + rows, columns, zp, + make_span(packed_zp)); + } +} + +void Sm80BlkQ4Prepack( + int block_size, bool column_quant_blk, + int rows, int columns, + gsl::span weights, + gsl::span scales, + gsl::span zp, + std::string& packed_w, + std::string& packed_scales, + std::string& packed_zp) { + switch (block_size) { + case 16: + if (column_quant_blk) { + Sm80BlkQ4PrepackT<16, true>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } else { + Sm80BlkQ4PrepackT<16, false>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } + break; + case 32: + if (column_quant_blk) { + Sm80BlkQ4PrepackT<32, true>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } else { + Sm80BlkQ4PrepackT<32, false>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } + break; + case 64: + if (column_quant_blk) { + Sm80BlkQ4PrepackT<64, true>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } else { + Sm80BlkQ4PrepackT<64, false>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } + break; + default: + ORT_THROW("Unsupported block size: ", block_size); + } +} + +/** + *@brief Prepack weights of the operator MatMulNBits. + * The caller should make sure the node is of type MatMulNBits. + */ +Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { + modified = false; + int64_t att_i; + + // + // Verify prepacking is needed and supported + // + Status status = graph_utils::TryGetNodeAttribute(node, "prepacked", att_i); + bool prepacked = status.IsOK() ? att_i != 0 : false; + if (prepacked) { + return Status::OK(); // already prepacked, nothing to do + } + + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "bits", att_i)); + int nbits = SafeInt(att_i); + if (nbits != 4) { + return Status::OK(); // only support 4 bits for now + } + + // A single dimension can not exceed 2G yet. + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "K", att_i)); + int k = SafeInt(att_i); + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "N", att_i)); + int n = SafeInt(att_i); + + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "block_size", att_i)); + int block_size = SafeInt(att_i); + + status = graph_utils::TryGetNodeAttribute(node, "column_wise_blocking", att_i); + bool column_wise_quant_blk = status.IsOK() ? att_i != 0 : true; + + auto* provider_info = TryGetProviderInfo_CUDA(); + ORT_ENFORCE(provider_info != nullptr, "Failed to query CUDA provider info while prepacking cuda operators."); + int major, minor; + ORT_ENFORCE(provider_info->GetCurrentGpuDeviceVersion(&major, &minor) == nullptr, + "Failed to query CUDA device version while prepacking cuda operators."); + + if (!onnxruntime::cuda::BlkQuantGemmSm80Supported(block_size, column_wise_quant_blk, k, n, major, minor)) { + return Status::OK(); // not supported + } + + // + // Verification passed, start prepacking + // + auto& node_name = node.Name(); + auto& mutable_input_defs = node.MutableInputDefs(); + if (mutable_input_defs.size() < 3 || mutable_input_defs.size() > 4) { + return Status::OK(); // not supported + } + + NodeArg* old_weights_arg = mutable_input_defs[1]; + NodeArg* old_scales_arg = mutable_input_defs[2]; + NodeArg* old_zp_arg = nullptr; + + // holders of the packed weight tensor memory + std::string packed_weights; + std::string packed_scales; + std::string packed_zp; + + { + // owners of the weight tensor memory, keep around until consumed by the prepacking function + OrtValue weights_val; + OrtValue scales_val; + OrtValue zp_val; + + ORT_RETURN_IF_ERROR(GetOrtValue(old_weights_arg, graph, weights_val)); + const gsl::span weights = weights_val.GetMutable()->DataAsSpan(); + + ORT_RETURN_IF_ERROR(GetOrtValue(old_scales_arg, graph, scales_val)); + const gsl::span scales = scales_val.GetMutable()->DataAsSpan(); + + gsl::span zp; + if (mutable_input_defs.size() > 3) { + old_zp_arg = mutable_input_defs[3]; + if (old_zp_arg != nullptr && old_zp_arg->Exists()) { + ORT_RETURN_IF_ERROR(GetOrtValue(old_zp_arg, graph, zp_val)); + Tensor* zp_tensor_ptr = zp_val.GetMutable(); + if (!zp_tensor_ptr->IsDataType()) { + return Status::OK(); // not supported + } + zp = zp_tensor_ptr->DataAsSpan(); + } + } + + Sm80BlkQ4Prepack(block_size, column_wise_quant_blk, k, n, weights, scales, zp, packed_weights, packed_scales, packed_zp); + +#if 0 + // debug print if prepacked tests fail + std::cout << " ====== packed weight ====== " << std::endl << std::hex; + const gsl::span packed_weights_span(reinterpret_cast(packed_weights.data()), packed_weights.size()); + for (int r = 0; r < k; r++) { + for (int c = 0; c < n/2; c++) { + std::cout << std::setw(2) << std::setfill('0') << static_cast(packed_weights_span[c * k + r]) << " "; + } + std::cout << std::endl; + } + std::cout << std::dec; +#endif + } + + // + // write packed weight tensor to node parameters + // + ONNX_NAMESPACE::TensorProto packed_weights_proto; + packed_weights_proto.set_name(graph.GenerateNodeArgName(node_name + "_prepacked_weight")); + packed_weights_proto.add_dims(packed_weights.size() / sizeof(uint8_t)); + packed_weights_proto.set_data_type(onnxruntime::utils::ToTensorProtoElementType()); + packed_weights_proto.set_raw_data(std::move(packed_weights)); + NodeArg& packed_weights_arg = graph_utils::AddInitializer(graph, packed_weights_proto); + graph.RemoveConsumerNode(old_weights_arg->Name(), &node); + mutable_input_defs[1] = &packed_weights_arg; + graph.AddConsumerNode(packed_weights_arg.Name(), &node); + + ONNX_NAMESPACE::TensorProto packed_scales_proto; + packed_scales_proto.set_name(graph.GenerateNodeArgName(node_name + "_prepacked_scales")); + packed_scales_proto.add_dims(packed_scales.size() / sizeof(MLFloat16)); + packed_scales_proto.set_data_type(onnxruntime::utils::ToTensorProtoElementType()); + packed_scales_proto.set_raw_data(std::move(packed_scales)); + NodeArg& packed_scales_arg = graph_utils::AddInitializer(graph, packed_scales_proto); + graph.RemoveConsumerNode(old_scales_arg->Name(), &node); + mutable_input_defs[2] = &packed_scales_arg; + graph.AddConsumerNode(packed_scales_arg.Name(), &node); + + if (!packed_zp.empty()) { + ONNX_NAMESPACE::TensorProto packed_zp_proto; + packed_zp_proto.set_name(graph.GenerateNodeArgName(node_name + "_prepacked_zp")); + packed_zp_proto.add_dims(packed_zp.size() / sizeof(uint8_t)); + packed_zp_proto.set_data_type(onnxruntime::utils::ToTensorProtoElementType()); + packed_zp_proto.set_raw_data(std::move(packed_zp)); + NodeArg& packed_zp_arg = graph_utils::AddInitializer(graph, packed_zp_proto); + graph.RemoveConsumerNode(old_zp_arg->Name(), &node); + mutable_input_defs[3] = &packed_zp_arg; + graph.AddConsumerNode(packed_zp_arg.Name(), &node); + } + + node.AddAttribute("prepacked", static_cast(1)); + modified = true; + return Status::OK(); +} + +Status GpuOpsPrepack::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + // int fused_count = 0; + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) + continue; // node was removed as part of an earlier fusion + + Node& node = *p_node; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (node.GetExecutionProviderType() != onnxruntime::kCudaExecutionProvider) { + continue; // only interested in CUDA nodes + } + + // Run prepack if the node is MatMulNBits. + // When we have more operators to support, we should use a map to dispatch the prepack function + // instead of adding a whole bunch of if branches here. + if (IsNodeMatMulNbitsFp16(node)) { + bool packed = false; + ORT_RETURN_IF_ERROR(PackMatMulNBitsFp16(node, graph, packed)); + modified |= packed; + continue; + } + } + + return Status::OK(); +} + +} // namespace onnxruntime + +#endif // USE_CUDA && !USE_ROCM diff --git a/onnxruntime/core/optimizer/gpu_ops_prepack.h b/onnxruntime/core/optimizer/gpu_ops_prepack.h new file mode 100644 index 0000000000000..0beecde4021d9 --- /dev/null +++ b/onnxruntime/core/optimizer/gpu_ops_prepack.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_CUDA) && !defined(USE_ROCM) + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@Class AttentionFusion +Rewrite graph fusing attention subgraph to a single Attention node. +*/ +class GpuOpsPrepack : public GraphTransformer { + public: + GpuOpsPrepack() noexcept + : GraphTransformer("GpuOpsPrepack", InlinedHashSet{onnxruntime::kCudaExecutionProvider}) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + private: +}; + +} // namespace onnxruntime + +#endif // USE_CUDA && !USE_ROCM diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4298551aec412..a8404a6587859 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -6,6 +6,7 @@ #include #include +#include "core/framework/execution_providers.h" #include "core/optimizer/conv_activation_fusion.h" #include "core/optimizer/matmul_nbits_fusion.h" #include "core/optimizer/nhwc_transformer.h" @@ -43,6 +44,7 @@ #include "core/optimizer/gemm_activation_fusion.h" #include "core/optimizer/gemm_sum_fusion.h" #include "core/optimizer/gemm_transpose_fusion.h" +#include "core/optimizer/gpu_ops_prepack.h" #include "core/optimizer/identical_children_consolidation.h" #include "core/optimizer/identity_elimination.h" #include "core/optimizer/label_encoder_fusion.h" @@ -186,8 +188,11 @@ std::unique_ptr GenerateRuleBasedGraphTransformer( InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, - const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const ExecutionProviders& execution_providers, /*required by constant folding*/ const InlinedHashSet& rules_and_transformers_to_disable) { + // CPU EP required to run constant folding + const IExecutionProvider& cpu_execution_provider = *execution_providers.Get(onnxruntime::kCpuExecutionProvider); + InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -387,7 +392,16 @@ InlinedVector> GenerateTransformers( // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, // while we can fuse more activation. transformers.emplace_back(std::make_unique(cpu_ep)); -#endif + +#if defined(USE_CUDA) && !defined(USE_ROCM) + // Cuda weight prepacking. + auto* cuda_ep = execution_providers.Get(onnxruntime::kCudaExecutionProvider); + if (cuda_ep != nullptr) { + transformers.emplace_back(std::make_unique()); + } +#endif // USE_CUDA && !USE_ROCM + +#endif // !defined(DISABLE_CONTRIB_OPS) } break; @@ -408,8 +422,9 @@ InlinedVector> GenerateTransformersForMinimalB TransformerLevel level, const SessionOptions& session_options, const SatApplyContextVariant& apply_context, - const IExecutionProvider& cpu_execution_provider, + const ExecutionProviders& execution_providers, const InlinedHashSet& rules_and_transformers_to_disable) { + const auto& cpu_execution_provider = *execution_providers.Get(onnxruntime::kCpuExecutionProvider); InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 103c79c93b2ca..ca794108d43c6 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -71,7 +71,7 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA { return nullptr; } - OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) override { + OrtStatus* GetCurrentGpuDeviceId(_Out_ int* device_id) override { auto cuda_err = cudaGetDevice(device_id); if (cuda_err != cudaSuccess) { return CreateStatus(ORT_FAIL, "Failed to get device id."); @@ -79,6 +79,23 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA { return nullptr; } + OrtStatus* GetCurrentGpuDeviceVersion(_Out_ int* major, _Out_ int* minor) override { + int device_id; + auto cuda_err = cudaGetDevice(&device_id); + if (cuda_err != cudaSuccess) { + return CreateStatus(ORT_FAIL, "Failed to get device id."); + } + cudaDeviceProp prop; + cuda_err = cudaGetDeviceProperties(&prop, device_id); + if (cuda_err != cudaSuccess) { + return CreateStatus(ORT_FAIL, "Failed to get device properties."); + } + *major = prop.major; + *minor = prop.minor; + + return nullptr; + } + std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) override { return std::make_unique(device_id, name); } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.h b/onnxruntime/core/providers/cuda/cuda_provider_factory.h index 4d5ef658f6be0..320eb4ed82cfc 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.h +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.h @@ -22,7 +22,8 @@ class NvtxRangeCreator; struct ProviderInfo_CUDA { virtual OrtStatus* SetCurrentGpuDeviceId(_In_ int device_id) = 0; - virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0; + virtual OrtStatus* GetCurrentGpuDeviceId(_Out_ int* device_id) = 0; + virtual OrtStatus* GetCurrentGpuDeviceVersion(_Out_ int* major, _Out_ int* minor) = 0; virtual std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) = 0; virtual std::unique_ptr CreateCUDAPinnedAllocator(const char* name) = 0; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 3ef6490a56ded..1e81f5b2bd592 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1613,15 +1613,15 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, - const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep) { + const InlinedHashSet& optimizers_to_disable, const ExecutionProviders& execution_providers) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); level <= static_cast(session_options.graph_optimization_level); ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( - static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable); + static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, + execution_providers, optimizers_to_disable); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -2007,9 +2007,8 @@ common::Status InferenceSession::Initialize() { *session_state_, session_options_.config_options, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( - ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, cpu_ep)); + ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, execution_providers_)); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3159,7 +3158,6 @@ common::Status InferenceSession::AddPredefinedTransformers( TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const { - const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { TransformerLevel level = static_cast(i); if (graph_optimization_level >= level) { @@ -3170,7 +3168,7 @@ common::Status InferenceSession::AddPredefinedTransformers( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; if (use_full_build_optimizations) { - return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, + return optimizer_utils::GenerateTransformers(level, session_options_, execution_providers_, optimizers_to_disable_); } else { const auto sat_context = @@ -3179,8 +3177,8 @@ common::Status InferenceSession::AddPredefinedTransformers( ? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{ record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; - return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, - optimizers_to_disable_); + return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, + execution_providers_, optimizers_to_disable_); } }(); diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index 43843da3fb96e..b4cf8cf518564 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -428,18 +428,18 @@ class MatrixRef { /// Returns a reference to the element at a given Coord ORT_FORCEINLINE Reference at(MatCoord const& coord) const { - return data_[offset(coord)]; + return data_[static_cast(offset(coord))]; } ORT_FORCEINLINE Reference at(int row, int col) const { - return data_[offset(make_Position(row, col))]; + return data_[static_cast(offset(make_Position(row, col)))]; } /// Returns a reference to the element at a given Coord ORT_FORCEINLINE Reference operator[](MatCoord const& coord) const { - return data_[offset(coord)]; + return data_[static_cast(offset(coord))]; } }; diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h index 942b1c4d2c2ad..4e9df0369c011 100644 --- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -15,12 +15,163 @@ #pragma once -#include "core/util/matrix_layout.h" +#include +#include "core/mickey/blk_q4/f16_prepack_sm80.h" #include "core/common/common.h" namespace onnxruntime { namespace test { +/** + * @brief Generate a set of quantized weights, scales and offsets + * and dequantized weights for testing quantization and + * dequantization. All outputs are column major layout. + * + * @tparam ElementT The type of the dequantized weights. + * @tparam block_size The block size of the quantization. + * @tparam col_blocking Whether to use column blocking (all elements of + * a block comes from a single column) or row blocking + * @tparam has_offsets Whether to generate offsets. + * + * @param[in] rows The number of rows of the weight matrix. + * @param[in] columns The number of columns of the weight matrix. + * @param[out] dequants The dequantized weights, column major layout. + * @param[out] q_weights The quantized weights, column major layout. + * @param[out] q_scales The scales, column major layout. + * @param[out] q_zp The zero points, column major layout. + */ +template +inline void blkq4_weights_gen( + int rows, int columns, + std::vector& dequants, + std::vector& q_weights, + std::vector& q_scales, + std::vector& q_zp) { + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + col_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + + static_assert(std::is_same::value); + static_assert(std::is_same::value); + static_assert(std::is_same::value); + + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution dis(0, 8192); + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + + // + // For testing quantization and dequantization, it is not straight + // forward to avoid flaky tests due to rounding errors. The way we + // try to achieve this is to: + // 1. Generate a set of quantized weights, scales and offsets + // 2. Dequantize the weights + // 3. Quantize the dequantized weights + // 4. Compare the dequantied-and-then-quantized weights with + // the original quantized weights + // + // Random filling of the initial values are key to get this right. + // For weights, we must ensure each block gets a full range of + // values, i.e. must contain 0 and 15. And for scales, they must + // all be positive. + // + + q_weights.resize(q_weight_shape.product()); + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + int v = 7; + for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { + for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); + } + } + + q_scales.resize(meta_shape.product()); + for (size_t i = 0; i < q_scales.size(); i++) { + uint32_t vscale = dis(gen); + uint32_t m = (vscale % 63) + 1; + uint32_t e = (vscale >> 6) % 4; + q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); + } + MatrixRef tensor_scale( + q_scales, meta_shape); + + MatrixRef tensor_offset; + if constexpr (has_offsets) { + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + q_zp.resize(zp_shape.product()); + tensor_offset = MatrixRef( + q_zp, zp_shape); + for (int c = 0; c < zp_shape[1]; c++) { + for (int r = 0; r < zp_shape[0]; ++r) { + uint8_t v0 = dis(gen) % 16; + uint8_t v1 = 8; + if (r * 2 + 1 < meta_shape[0]) { + v1 = dis(gen) % 16; + } + tensor_offset.at(r, c) = static_cast(v0 | (v1 << 4)); + } + } + } + + dequants.resize(rows * columns); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + + // Dequantize weights and save into matrix B + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + auto weight_cord = make_Position(row / 2, col); + auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + uint8_t offset = 8; + if constexpr (has_offsets) { + if (scale_cord[0] % 2 == 0) { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f; + } else { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4; + } + } + int w = 0; + if (row % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + float scale = float(tensor_scale.at(scale_cord)); + float dequant = scale * float(w - offset); + tensor_dequant.at(row, col) = ElementT(dequant); + // Prints for help debugging in case of test failure + // fprintf(stderr, "%f~%d-%d|%f, ", dequant, w, offset, scale); + } + // fprintf(stderr, "\n"); + } +} + static inline void sm80_prepack_weights_ref( int rows, int columns, diff --git a/onnxruntime/test/optimizer/gpu_op_prepack_test.cc b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc new file mode 100644 index 0000000000000..949fbef3da689 --- /dev/null +++ b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc @@ -0,0 +1,393 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "gtest/gtest.h" +#include "graph_transform_test_builder.h" +#include "core/mlas/inc/mlas.h" +#include "core/graph/graph.h" +#include "core/optimizer/initializer.h" + +#include "core/mlas/inc/mlas_q4.h" +#include "core/providers/cuda/cuda_provider_factory_creator.h" +#include "core/mickey/blk_q4/f16_prepack_sm80.h" +#include "core/providers/cuda/cuda_provider_factory.h" +#include "core/providers/cuda/cuda_execution_provider_info.h" + +#include "test/cuda_host/blkq4_fp16_quant_sm80.h" +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +namespace onnxruntime { + +extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); + +namespace test { +#ifndef DISABLE_CONTRIB_OPS + +std::shared_ptr LoadCudaEp() { + try { + OrtCUDAProviderOptions cuda_options; + auto factory = onnxruntime::CudaProviderFactoryCreator::Create(&cuda_options); + if (!factory) { + return nullptr; + } + return factory->CreateProvider(); + } catch (const ::onnxruntime::OnnxRuntimeException& e) { + std::cerr << "LoadCudaEp: " << e.what() << std::endl; + return nullptr; + } +} + +/** + * @brief Testing helper for GPU prepacking logic in the graph transformer. + * This is an modification of the TransformerTester function from + * onnxruntime/test/optimizer/graph_transform_test_builder.cc + * with: + * - the addition of cuda execution provider in the session. + * - a different location for the model checker, right after session initialization + * as the initializers will be deleted during session run. + */ +void GpuPrepackTester( + const std::shared_ptr& cuda_ep, + const std::function& build_test_case, + const std::function& check_transformed_graph, + TransformerLevel baseline_level, + TransformerLevel target_level, + int opset_version = 12, + double per_sample_tolerance = 0.001, + double relative_per_sample_tolerance = 0.001, + const std::function& add_session_options = {}, + const InlinedHashSet& disabled_optimizers = {}) { + // Build the model for this test. + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = opset_version; + domain_to_version[kMSDomain] = 1; + Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + ASSERT_TRUE(build_test_case); + build_test_case(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + auto run_model = [&](TransformerLevel level, std::vector& fetches) { + SessionOptions session_options; + session_options.graph_optimization_level = level; + if (level == target_level) { + // we don't really need the model file, but it seems to be the only way to keep the + // transformed initializers so that they can be verified. + session_options.optimized_model_filepath = + ToPathString("gpu_prepack_test_model_opt_level_" + std::to_string(static_cast(level)) + ".onnx"); + } + if (add_session_options) { + add_session_options(session_options); + } + InferenceSessionWrapper session{session_options, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(cuda_ep)); + + ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size()))); + if (!disabled_optimizers.empty()) { + ASSERT_STATUS_OK(session.FilterEnabledOptimizers(InlinedHashSet{disabled_optimizers})); + } + + ASSERT_STATUS_OK(session.Initialize()); + + RunOptions run_options; + ASSERT_STATUS_OK(session.Run(run_options, + helper.feeds_, + helper.output_names_, + &fetches)); + + if (level == target_level) { + if (check_transformed_graph) { + check_transformed_graph(session); + } + } + }; + + std::vector baseline_fetches; + ASSERT_NO_FATAL_FAILURE(run_model(baseline_level, baseline_fetches)); + + std::vector target_fetches; + ASSERT_NO_FATAL_FAILURE(run_model(target_level, target_fetches)); + + size_t num_outputs = baseline_fetches.size(); + ASSERT_EQ(num_outputs, target_fetches.size()); + + for (size_t i = 0; i < num_outputs; i++) { + std::pair ret = + CompareOrtValue(target_fetches[i], + baseline_fetches[i], + per_sample_tolerance, + relative_per_sample_tolerance, + false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } +} + +inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_value) { + const ONNX_NAMESPACE::TensorProto* tensor_proto; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto), + "Missing initializer for ", arg->Name()); + + return utils::TensorProtoToOrtValue( + Env::Default(), graph.ModelPath(), *tensor_proto, + std::make_shared(), ort_value); +} + +template +void MatMulQ4Test(int M, int N, int K, const std::shared_ptr& cuda_ep) { + // + // Type definitions + // + using ElementT = MLFloat16; + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + columnwise_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + using LayoutQmeta = typename Base::LayoutQmeta; + + // + // Generate random inputs + // + const auto q_weight_shape = Base::get_quant_weights_shape(K, N); + const auto meta_shape = Base::get_quant_meta_shape(K, N); + + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + blkq4_weights_gen( + K, N, dequants, q_weights, q_scales, q_zp); + + // for quantization tool, the input is row major, all outputs are column major + MatrixRef tensor_q_weight( + q_weights, q_weight_shape); + MatrixRef tensor_scale( + q_scales, meta_shape); + MatrixRef tensor_offset; + if constexpr (has_offsets) { + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + tensor_offset = MatrixRef(q_zp, zp_shape); + } + + // Compute prepacked weights + std::vector packed_w_ref(q_weight_shape.product()); + MatrixRef tensor_packed_w_ref( + packed_w_ref, make_Position(K, N / 2)); + onnxruntime::test::sm80_prepack_weights_ref(K, N, tensor_q_weight, tensor_packed_w_ref); + + std::vector packed_scales_ref(meta_shape.product()); + MatrixRef tensor_packed_s_ref = + make_MatrixRef(packed_scales_ref, meta_shape); + if constexpr (Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_scales_ref( + K, N, tensor_scale.const_ref(), tensor_packed_s_ref); + } else { + for (int col = 0; col < tensor_packed_s_ref.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s_ref.shape()[0]; ++row) { + tensor_packed_s_ref.at(row, col) = tensor_scale.at(row, col); + } + } + } + + std::vector packed_zp_ref; + if constexpr (has_offsets) { + packed_zp_ref.resize(meta_shape.product()); + MatrixRef tensor_packed_zp_ref = + make_MatrixRef(packed_zp_ref, meta_shape); + if constexpr (Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_offsets_ref( + K, N, tensor_offset.const_ref(), tensor_packed_zp_ref); + } else { + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_packed_zp_ref.at(row, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_packed_zp_ref.at(row + 1, col) = pair01 >> 4; + } + } + } + } + } + + auto build_test_case = [&](ModelTestBuilder& builder) { + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise_blocking, K, N, + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + auto* input_arg = builder.MakeInput({M, K}, MLFloat16(-2.0f), MLFloat16(2.0f)); + constexpr bool print_input_a = false; + if constexpr (print_input_a) { + const auto& act_name = input_arg->Name(); + OrtValue act_val = builder.feeds_[act_name]; + const gsl::span act_data = act_val.GetMutable()->DataAsSpan(); + ASSERT_EQ(act_data.size(), M * K); + fprintf(stderr, "====== act_data ======:\n"); + for (int act_row = 0; act_row < M; act_row++) { + for (int act_col = 0; act_col < K; act_col++) { + fprintf(stderr, "%f, ", static_cast(act_data[act_row * K + act_col])); + } + fprintf(stderr, "\n"); + } + } + + auto* output_arg = builder.MakeOutput(); + auto* weight_arg = builder.MakeInitializer({q_weight_shape[1], q_weight_shape[0]}, q_weights); + auto* scale_arg = builder.MakeInitializer({static_cast(q_scales.size())}, q_scales); + + std::vector input_args{input_arg, weight_arg, scale_arg}; + if constexpr (has_offsets) { + auto* zero_point_arg = builder.MakeInitializer({static_cast(q_zp.size())}, q_zp); + input_args.push_back(zero_point_arg); + } else { + ASSERT_TRUE(q_zp.empty()); + } + Node& node = builder.AddNode("MatMulNBits", input_args, {output_arg}, kMSDomain); + node.AddAttribute("K", static_cast(K)); + node.AddAttribute("N", static_cast(N)); + node.AddAttribute("block_size", static_cast(block_size)); + node.AddAttribute("bits", static_cast(4)); + node.AddAttribute("column_wise_blocking", static_cast(columnwise_blocking)); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + const auto& graph = session.GetGraph(); + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + ASSERT_EQ(node.Domain(), kMSDomain); + ASSERT_EQ(node.GetAttributes().at("prepacked").i(), 1); + { + // Verify prepacked weights + OrtValue packed_w_val; + ASSERT_STATUS_OK(GetOrtValue(node.InputDefs()[1], graph, packed_w_val)); + const gsl::span weights_data = packed_w_val.GetMutable()->DataAsSpan(); + ASSERT_EQ(weights_data.size(), packed_w_ref.size()); + for (size_t i = 0; i < packed_w_ref.size(); ++i) { + int expected = packed_w_ref[i]; + int found = weights_data[i]; + ASSERT_EQ(expected, found) << "prepacked weight mismatch index i = " << i << " shape[" << K << "," << N / 2 << "]"; + } + } + { + // Verify prepacked scales + OrtValue packed_s_val; + ASSERT_STATUS_OK(GetOrtValue(node.InputDefs()[2], graph, packed_s_val)); + const gsl::span scales_data = packed_s_val.GetMutable()->DataAsSpan(); + ASSERT_EQ(scales_data.size(), packed_scales_ref.size()); + for (size_t i = 0; i < packed_scales_ref.size(); ++i) { + float expected = packed_scales_ref[i]; + float found = scales_data[i]; + ASSERT_EQ(expected, found) << "prepacked scale mismatch index i = " << i << " shape[" << meta_shape[0] << "," << meta_shape[1] << "]"; + } + } + if constexpr (has_offsets) { + // Verify prepacked zero points + OrtValue packed_z_val; + ASSERT_STATUS_OK(GetOrtValue(node.InputDefs()[3], graph, packed_z_val)); + const gsl::span offsets_data = packed_z_val.GetMutable()->DataAsSpan(); + ASSERT_EQ(offsets_data.size(), packed_zp_ref.size()); + for (size_t i = 0; i < packed_zp_ref.size(); ++i) { + int expected = packed_zp_ref[i]; + int found = offsets_data[i]; + ASSERT_EQ(expected, found) << "prepacked zero-point mismatch index i = " << i << " shape[" << meta_shape[0] << "," << meta_shape[1] << "]"; + } + } else { + ASSERT_LE(node.InputDefs().size(), static_cast(3)); + } + } + } + }; + + GpuPrepackTester(cuda_ep, + build_test_case, + check_graph, + TransformerLevel::Level2, + TransformerLevel::Level3); +} + +TEST(GpuOpPrepackTests, MatmulNBits) { + std::shared_ptr provider = LoadCudaEp(); + if (!provider) { + GTEST_SKIP() << "Skipping tests when CUDA EP is not available"; + } + + // + // Currently these tests only work on sm_80. Going forward, however, + // we need a better solution when we may have different tests for different + // hardware. + // + auto* provider_info = TryGetProviderInfo_CUDA(); + int major, minor; + ORT_ENFORCE(provider_info->GetCurrentGpuDeviceVersion(&major, &minor) == nullptr, + "Failed to query CUDA device version while prepacking cuda operators."); + if (major < 8) { + GTEST_SKIP() << "Skipping tests when CUDA EP is not sm_80"; + } + + // + // GpuPrepackTester function implements two different verifications. + // First is the hook check_graph, which we use to verify the prepacked weights, scales and zero points. + // Second is the comparison of the outputs of the model with and without prepacking, this actually + // doubles as a verification of kernel correctness and prepacking correctness. + // + // We do have other sets of tests for the prepack and kernel correctness, defined in + // onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc + // + // What we are doing here is to verify we correctly connected the prepacking logic in + // the graph transformer, and that the prepacked weights, scales and zero points are correctly + // passed to the kernel in MatMulNBits cuda op. Plus the redundant verifications allows us to + // locate the problem more easily. + // + // So we don't need to test all the combinations here, just a few representative ones. + // + + std::cout << "Testing MatMulQ4Test<64, true, true>(4, 128, 64, provider)" << std::endl; + MatMulQ4Test<64, true, true>(4, 128, 64, provider); + std::cout << "Testing MatMulQ4Test<64, false, true>(4, 128, 64, provider)" << std::endl; + MatMulQ4Test<64, false, true>(4, 128, 64, provider); + std::cout << "Testing MatMulQ4Test<64, true, false>(8, 128, 64, provider)" << std::endl; + MatMulQ4Test<64, true, false>(8, 128, 64, provider); + std::cout << "Testing MatMulQ4Test<64, false, false>(8, 128, 64, provider)" << std::endl; + MatMulQ4Test<64, false, false>(8, 128, 64, provider); + + std::cout << "Testing MatMulQ4Test<32, true, true>(16, 64, 128, provider)" << std::endl; + MatMulQ4Test<32, true, true>(16, 64, 128, provider); + std::cout << "Testing MatMulQ4Test<32, true, false>(16, 64, 128, provider)" << std::endl; + MatMulQ4Test<32, true, false>(16, 64, 128, provider); + std::cout << "Testing MatMulQ4Test<32, false, true>(16, 64, 128, provider)" << std::endl; + MatMulQ4Test<32, false, true>(16, 64, 128, provider); + std::cout << "Testing MatMulQ4Test<32, false, false>(16, 64, 128, provider)" << std::endl; + MatMulQ4Test<32, false, false>(16, 64, 128, provider); + + std::cout << "Testing MatMulQ4Test<16, true, true>(32, 96, 128, provider)" << std::endl; + MatMulQ4Test<16, true, true>(32, 96, 128, provider); + std::cout << "Testing MatMulQ4Test<16, true, false>(32, 96, 128, provider)" << std::endl; + MatMulQ4Test<16, true, false>(32, 96, 128, provider); + std::cout << "Testing MatMulQ4Test<16, false, true>(32, 96, 128, provider)" << std::endl; + MatMulQ4Test<16, false, true>(32, 96, 128, provider); + std::cout << "Testing MatMulQ4Test<16, false, false>(32, 96, 128, provider)" << std::endl; + MatMulQ4Test<16, false, false>(32, 96, 128, provider); +} + +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index f83fb8238ff61..fd061f8a417d0 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -14,6 +14,7 @@ #include "core/common/span_utils.h" #include "core/framework/data_types.h" +#include "core/framework/execution_providers.h" #include "core/framework/ort_value.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" @@ -826,12 +827,15 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_map e = - std::make_unique(CPUExecutionProviderInfo()); + std::shared_ptr e = + std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(e))); bool has_constant_folding = false; onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_options, *e.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_options, + execution_providers, {}); for (auto& transformer : transformers) { if (transformer->Name() == "ConstantFolding") { ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), TransformerLevel::Level1)); @@ -4636,11 +4640,14 @@ TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) { } static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_options) { - std::unique_ptr e = - std::make_unique(CPUExecutionProviderInfo()); + std::shared_ptr e = + std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(e))); bool has_gelu_approximation = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, + execution_providers, {}); for (auto& transformer : transformers) { if (transformer->Name() == "GeluApproximation") { has_gelu_approximation = true; @@ -4653,9 +4660,13 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt // Test session option configuration for DoubleQDQPairsRemover TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) { auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) { - std::unique_ptr cpu_ep = std::make_unique(CPUExecutionProviderInfo()); + std::shared_ptr cpu_ep = std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(cpu_ep))); + bool has_double_qdq_remover = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, + execution_providers, {}); for (auto& transformer : transformers) { if (transformer->Name() == "DoubleQDQPairsRemover") { has_double_qdq_remover = true; diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index 73c8b3f119103..0a0e61f8f39b1 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -170,6 +170,7 @@ void TransformerTester(const std::function& buil add_session_options(session_options); } InferenceSessionWrapper session{session_options, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size()))); if (transformer) { ASSERT_STATUS_OK(session.RegisterGraphTransformer(std::move(transformer), level)); diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index 66b74641e41d3..919d08cd628ea 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -35,10 +35,15 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { std::string l1_transformer = "ConstantFolding"; std::string l2_transformer = "ConvActivationFusion"; InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; - CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + std::shared_ptr e = + std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + auto status = execution_providers.Add(kCpuExecutionProvider, std::move(e)); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, execution_providers); + auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, + execution_providers, disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -61,8 +66,9 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, execution_providers); + filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, + execution_providers, disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index b95e093e41eab..c8fde40a8ee01 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -33,6 +33,9 @@ void testPrepack(int rows, int columns) { 4, col_blocking>; + EXPECT_TRUE(Base::weight_dimension_supported(rows, columns)) + << "Test setup problem, unsupported weight dimension: [" << rows << ", " << columns << "]"; + using QuantBlocking = typename Base::QuantBlocking; using ElementW = typename Base::ElementW; using LayoutWPack = typename Base::LayoutWPack; @@ -276,8 +279,8 @@ TEST(BlkQ4_GEMM, Sm80RowBlockingTest) { onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 64); onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 64); - onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 192); - onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 192); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(64, 96, 192); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(64, 96, 192); onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(256, 672, 576); onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(256, 672, 576); @@ -316,8 +319,8 @@ TEST(BlkQ4_GEMM, Sm80SmallMTest) { onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, false>(16, 704, 576); onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, true>(16, 704, 576); - onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, false>(16, 1024, 576); - onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, true>(16, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, false>(32, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, true>(32, 1024, 576); onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, false>(16, 672, 576); onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, true>(16, 672, 576); diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index f5600ca9885a3..338070e73ebce 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -29,6 +29,9 @@ #include "core/common/common.h" +#include "blkq4_fp16_gemm_sm80.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" + namespace onnxruntime { namespace cuda { namespace test { @@ -200,6 +203,8 @@ void run_blkq4_gemm(int m, int n, int k) { block_size, 4, column_wise_blocking>; + ORT_ENFORCE(PrepackT::weight_dimension_supported(k, n), + "Test setup problem, unsupported weight dimension: [", k, ", ", n, "]"); thrust::host_vector packed_w(q_weight_shape.product()); PrepackT::prepack_weights(problem_size.k(), problem_size.n(), q_weights, packed_w); @@ -256,19 +261,16 @@ void run_blkq4_gemm(int m, int n, int k) { tensor_d.sync_device(); // run GEMM - cutlass::Status status; - if constexpr (has_offsets) { - status = GemmRunner::run( - nullptr, problem_size, tensor_a.device_ref(), ref_W, - ref_scales, ref_zp, - tensor_c.device_ref(), tensor_d.device_ref()); - } else { - status = GemmRunner::run( - nullptr, problem_size, tensor_a.device_ref(), ref_W, - ref_scales, - tensor_c.device_ref(), tensor_d.device_ref()); - } - ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); + ORT_THROW_IF_ERROR( + onnxruntime::contrib::cuda::blkq4_fp16_gemm_sm80_dispatch( + block_size, column_wise_blocking, + m, n, k, + nullptr, + tensor_a.device_data(), tensor_a.size(), + d_packed_w.data().get(), d_packed_w.size(), + d_packed_scales.data().get(), d_packed_scales.size(), + d_packed_zp.data().get(), d_packed_zp.size(), + tensor_d.device_data(), tensor_d.size())); // Running reference kernel using ElementInputB = ElementInputA; diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index 72357ec7e02d2..e6c86f41300af 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -5,7 +5,6 @@ // extra code in the core of CUDA EP and that code may // 1. slow down performance critical applications and // 2. increase binary size of ORT. - #include "gtest/gtest.h" #include diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc index d8384b432786b..d148fd6aa6ef0 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc @@ -43,6 +43,12 @@ struct ProviderInfo_CUDA_TestImpl : ProviderInfo_CUDA { return nullptr; } + virtual OrtStatus* GetCurrentGpuDeviceVersion(_Out_ int* major, _Out_ int* minor) override { + *major = 0; + *minor = 0; + return nullptr; + } + std::unique_ptr CreateCUDAAllocator(int16_t, const char*) override { return nullptr; } diff --git a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc index 548f39bb0150c..2e8d85f38b2db 100644 --- a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc @@ -21,10 +21,15 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { std::string l1_transformer = "ConstantFolding"; std::string l2_transformer = "ConvActivationFusion"; InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; - CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + std::shared_ptr cpu_ep = + std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + auto status = execution_providers.Add(kCpuExecutionProvider, std::move(cpu_ep)); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, execution_providers); + auto filtered_transformers = optimizer_utils::GenerateTransformers( + TransformerLevel::Level1, {}, execution_providers, disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -47,8 +52,9 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, execution_providers); + filtered_transformers = optimizer_utils::GenerateTransformers( + TransformerLevel::Level2, {}, execution_providers, disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif }