From de14fbc63beafac2eb52a2398bc72b13b384262f Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 23 Feb 2024 13:50:48 +0800 Subject: [PATCH] cuda kernel ready --- .../cuda/quantization/dequantize_blockwise.cu | 159 ++++++++++++++---- .../quantization/dequantize_blockwise.cuh | 6 +- .../cuda/quantization/matmul_nbits.cc | 66 +++++--- .../core/graph/contrib_ops/contrib_defs.cc | 7 +- .../test/contrib_ops/matmul_4bits_test.cc | 43 ++++- 5 files changed, 209 insertions(+), 72 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 6b66f1d84e221..ba8e8511fbb79 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -2,10 +2,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include #include +#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" @@ -56,41 +58,94 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f } template -__global__ void Dequantize4BitsKernel( +__global__ void Dequantize4BitsKernelReOrder( T* output, const uint8_t* quant_data, const T* scale_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int block_size, - int blocks_per_K, - int blocks_per_threadblock, - int total_blks, - int shift) { - int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); - if (block_id >= total_blks) { + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (group_id >= total_groups) { return; } - int n_idx = block_id / blocks_per_K; - int kb_idx = block_id % blocks_per_K; - int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + //T __shared__ zero_points_after_reorder[];//K + //T __shared__ scales_after_reorder[]; // K + //const int num_r_per_thread = k / 256; + + const int zero_point_shape_x = (groups_per_K + 1) / 2; + const int scales_shape_x = groups_per_K; + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); + T* output_i = output + element_offset; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + for (int i = 0; i < 8; i++) { + int32_t rid = reorder_idx[kb_idx * block_size + i]; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * __short2half_rn(zp); + output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * T(zp); + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const ZeroT* zero_points, + int block_size, + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (block_id >= total_groups) { + return; + } + const int zero_point_shape_x = (groups_per_K + 1) / 2; + const int scales_shape_x = groups_per_K; + int n_idx = block_id / scales_shape_x; + int kb_idx = block_id % scales_shape_x; + int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); - uint8_t zp = 8; - if (zero_points) { - zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; - zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + T zero_point_value; + if constexpr(std::is_same_v) { + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2]; + zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + zero_point_value = static_cast(zp); + } else { + zero_point_value = zero_points? *(zero_points + block_id):static_cast(8); } output = output + element_offset; - DequantizeEightElements(quant_value, scale, static_cast(zp), output); + DequantizeEightElements(quant_value, scale, zero_point_value, output); } -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + const ZeroT* zero_points, // shape: [N, (block_per_K + 1)/2] + const int32_t* reorder_idx, int k, int n, int block_size, @@ -98,47 +153,79 @@ Status Dequantize4Bits( // k is padded and equal to block_per_K * block_size ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); constexpr int element_per_thread = 8; - int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; - int blocks_per_K = k / block_size; - int total_blks = n * blocks_per_K; - int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); - int shift = static_cast(log2f(float(block_size))); - - Dequantize4BitsKernel<<>>( - output, - quant_data, - scales_data, - zero_points, - block_size, - blocks_per_K, - blocks_per_threadblock, - total_blks, - shift); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_K = k / block_size; + int total_groups = n * groups_per_K; // total elemenets in quant_data + int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); + if (!reorder_idx) { + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + //static_assert(std::is_same_v, "ZeroT must be uint8_t"); + Dequantize4BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } return Status::OK(); } -template Status Dequantize4Bits( +template Status Dequantize4Bits( float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); -template Status Dequantize4Bits( +template Status Dequantize4Bits( half* output, const uint8_t* quant_data, const half* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const float* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const half* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); /////////////////////////////////////////////////////////////////////////////// // A more general block-wise dequantization implementation that supports // different block sizes and block orientations (row-wise/column-wise). diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index f9c09c55fd893..580b5087f3fa3 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -7,18 +7,18 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, + const ZeroT* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 12cf5c83def33..fcecce8ce9999 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -8,6 +8,7 @@ // #include "matmul_nbits.h" +#include #include "core/common/status.h" #include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" @@ -25,11 +26,13 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); + const Tensor* reorder_idx = ctx->Input(4); const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); typedef typename ToCudaType::MappedType CudaT; @@ -44,33 +47,50 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - SafeInt(GetDeviceProp().sharedMemPerBlock), - static_cast(ctx->GetComputeStream()->GetHandle())); + bool is_4bit_done = (reorder_idx_data == nullptr) && + (!zero_points || !zero_points->IsDataType()) && + TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + if (!is_4bit_done) { int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); auto* b_data = b_data_ptr.get(); if (column_wise_quant_blk_) { // column-wise block - ORT_RETURN_IF_ERROR(Dequantize4Bits( - reinterpret_cast(b_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(K_padded), - SafeInt(N_), - SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + if ((zero_points && zero_points->IsDataType())) { + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const CudaT*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } } else { // row-wise block K_padded = K_; @@ -79,7 +99,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, + (const uint8_t*)zero_points_data, SafeInt(block_size_), column_wise_quant_blk_, SafeInt(K_), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 354345327f5da..bd0f82d2b231c 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3368,11 +3368,12 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored .Input(1, "B", "1 or 2 dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional) - .Input(4, "g_idx", "group_idx for gptq", "T2", OpSchema::Optional) + .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)", "tensor(uint32)", "tensor(int32)"}, "Constrain quantized weight types to uint8/uint32/int32/float16.") - .TypeConstraint("T3", {"tensor(uint8)", "tensor(uint32)", "tensor(int32)", "tensor(float16)"}, "Constrain quantized zero point types to uint8/uint32/int32/float16.") + .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/uint32/int32/float16.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") + .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 2ad20eafc2ef1..be8383e64e89b 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #ifndef ORT_MINIMAL_BUILD #include "core/common/span_utils.h" @@ -66,7 +67,8 @@ void QuantizeDequantize(std::vector& raw_vals, } void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, - bool has_zeropoint, bool use_float16, float fp16_abs_error = 0.02f) { + bool has_zeropoint, bool use_float16, bool has_g_idx = false, bool zp_is_4bit = true, float fp16_abs_error = 0.02f) { + zp_is_4bit = zp_is_4bit|has_g_idx; RandomValueGenerator random{1234}; std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); std::vector input1_f_vals(random.Gaussian(std::vector({K, N}), 0.0f, 0.25f)); @@ -118,7 +120,32 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddInput("B", {q_cols, q_rows}, input1_vals, true); test.AddInput("scales", {static_cast(q_scale_size)}, ToFloat16(scales), true); if (has_zeropoint) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + if (zp_is_4bit){ + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + } else { + std::vector zp_f; + zp_f.reserve(q_zp_size_in_bytes*2); + for (size_t i = 0; i < zp.size(); i++) { + zp_f.push_back(static_cast(zp[i] & 0xf)); + zp_f.push_back(static_cast((zp[i] >> 4) & 0xf)); + } + size_t ind = zp_f.size()-1; + while(zp_f.size() != q_scale_size){ + zp_f.erase(zp_f.begin() + ind); + ind -= q_scale_size/N+1; + } + + test.AddInput("zero_points", {static_cast(q_scale_size)}, ToFloat16(zp_f), true); + } + } else { + test.AddInput("", {0}, {}); + } + if (has_g_idx) { + std::vector g_idx(K); + for (int64_t i = 0; i < K; i++) { + g_idx[i] = gsl::narrow(i/block_size); + } + test.AddInput("g_idx", {static_cast(K)}, g_idx, true); } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); @@ -172,8 +199,10 @@ TEST(MatMulNBits, Float16) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { - RunTest(M, N, K, block_size, 0, false, true); - RunTest(M, N, K, block_size, 0, true, true); + for (auto has_gidx : {false,true, false}) { + RunTest(M, N, K, block_size, 0, false, true, has_gidx); + RunTest(M, N, K, block_size, 0, true, true, has_gidx, false); + } } } } @@ -183,9 +212,9 @@ TEST(MatMulNBits, Float16) { TEST(MatMulNBits, Float16Large) { for (auto block_size : {16, 32, 64, 128}) { for (auto symmetric : {false, true}) { - RunTest(1, 4096, 4096, block_size, 0, symmetric, true, 0.05f); - RunTest(1, 4096, 11008, block_size, 0, symmetric, true, 0.05f); - RunTest(1, 11008, 4096, block_size, 0, symmetric, true, 0.05f); + RunTest(1, 4096, 4096, block_size, 0, symmetric, true, false, true, 0.05f); + RunTest(1, 4096, 11008, block_size, 0, symmetric, true, false, true, 0.05f); + RunTest(1, 11008, 4096, block_size, 0, symmetric, true, false, true, 0.05f); } } }