From 11af34440a3e216903fcc9dd7e169969fdc24d03 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 13 Oct 2023 16:55:30 -0700 Subject: [PATCH] Add MatMul 4bits support on GPU (#17890) ### Description Add a contrib op MatMulNBits and related toolchain to support quantization on weight. This PR only adds support for 4bits. It: - add schema for contrib op MatMulNBits which can support 1-7 bits quantization on weight. - a naive implementation for 4bits MatMulNBits on CPU and GPU, i.e., implemented like MatMul(A, Dequantize(B)). - a special implementation for GemV for 4bits MatMulNBits and related benchmark tool - tool to quantization model with 4bits. Next: - add general and more efficient kernels for 4bits MatMulNBits on CPU and GPU --- cmake/onnxruntime_rocm_hipify.cmake | 5 + docs/ContribOperators.md | 73 ++++++ docs/OperatorKernels.md | 2 + .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../cpu/quantization/blockwise_quant_block.h | 129 ++++++++++ .../cpu/quantization/dequantize_blockwise.h | 174 +++++++++++++ .../cpu/quantization/matmul_nbits.cc | 114 +++++++++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 + .../cuda/quantization/dequantize_blockwise.cu | 131 ++++++++++ .../quantization/dequantize_blockwise.cuh | 23 ++ .../cuda/quantization/matmul_nbits.cc | 149 ++++++++++++ .../cuda/quantization/matmul_nbits.cu | 217 +++++++++++++++++ .../cuda/quantization/matmul_nbits.cuh | 27 +++ .../core/graph/contrib_ops/contrib_defs.cc | 78 ++++++ .../shared_library/provider_wrappedtypes.h | 1 + .../python/onnxruntime_pybind_module.cc | 2 + .../python/onnxruntime_pybind_quant.cc | 73 ++++++ .../tools/kernel_explorer/device_array.h | 4 +- .../kernels/cuda/dequant_blockwise_int4.cu | 78 ++++++ .../kernel_explorer/kernels/cuda/gemm.cu | 94 +++++++ .../kernels/cuda/matmul_4bits.cu | 102 ++++++++ .../kernels/dequantize_blockwise_int4.py | 78 ++++++ .../kernel_explorer/kernels/gemm_test.py | 1 + .../kernel_explorer/kernels/matmul_4bits.py | 132 ++++++++++ 
.../quantization/matmul_4bits_quantizer.py | 229 ++++++++++++++++++ .../test/contrib_ops/matmul_4bits_test.cc | 163 +++++++++++++ .../test/contrib_ops/matmul_fpq4_test.cc | 2 +- .../quantization/test_op_matmul_4bits.py | 164 +++++++++++++ .../test_quantizeblockwise_4bits.py | 141 +++++++++++ 29 files changed, 2389 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu create mode 100644 onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh create mode 100644 onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu create mode 100644 onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh create mode 100644 onnxruntime/python/onnxruntime_pybind_quant.cc create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py create mode 100644 onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py create mode 100644 onnxruntime/test/contrib_ops/matmul_4bits_test.cc create mode 100644 onnxruntime/test/python/quantization/test_op_matmul_4bits.py create mode 100644 onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 6bab3babab0f9..af95d0203544c 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake 
+++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -52,6 +52,11 @@ set(contrib_ops_excluded_files "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" "quantization/attention_quantization_impl.cuh" + "quantization/dequantize_blockwise.cuh" + "quantization/dequantize_blockwise.cu" + "quantization/matmul_nbits.cc" + "quantization/matmul_nbits.cuh" + "quantization/matmul_nbits.cu" "quantization/quantize_dequantize_linear.cc" "quantization/qordered_ops/qordered_attention_impl.cu" "quantization/qordered_ops/qordered_attention_impl.h" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 4367270c7c73b..7e67ec6d0c94e 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -50,6 +50,7 @@ Do not modify directly.* * com.microsoft.MatMulFpQ4 * com.microsoft.MatMulInteger16 * com.microsoft.MatMulIntegerToFloat + * com.microsoft.MatMulNBits * com.microsoft.MaxpoolWithMask * com.microsoft.MulInteger * com.microsoft.MultiHeadAttention @@ -2634,6 +2635,78 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulNBits** + + MatMulNBits is a MatMul with weight quantized with N bits (e.g., 2, 3, 4, 5, 6, 7). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwise along dimension 0 (e.g. column) with block size specified by attribute block_size. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ... + 3. Input B's scale and zero point are specified by input scales and zero_points. 
+ + Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = block_size / 8 * bits + + Each block blob is stored in the following format: + struct Blob { + uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization + uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization + uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization + } + + Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] + Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one uint8_t. If bits > 4, one zero point is stored with one uint8_t. Thus, its shape is: + - [(N * n_blocks_per_col + 1) / 2] if bits <= 4 + - [N * n_blocks_per_col] if bits > 4 + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
bits : int (required)
+
number of bits used for weight quantization (default 4)
+
block_size : int (required)
+
Group (block) size used for weight quantization (default 128). It needs to be a power of 2 and not smaller than 16.
+
+ +#### Inputs (3 - 4) + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional data blob
+
scales : T1
+
quantization scale
+
zero_points (optional) : T2
+
quantization zero points
+
+ +#### Outputs + +
+
Y : T1
+
tensor. The output tensor has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16)
+
Constrain input and output types to float/half_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MaxpoolWithMask** For internal use. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b5c8fdb4bfd1a..e2d500006b05f 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -457,6 +457,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -846,6 +847,7 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 6242d7161f333..b4c51ab290eb7 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -28,6 +28,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordC class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4); #endif @@ -264,6 +265,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifndef ORT_MINIMAL_BUILD BuildKernelCreateInfo, #endif diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h new file mode 100644 index 0000000000000..11b5447d65ed2 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if defined(_MSC_VER) +#define FORCEINLINE __forceinline +#else +#define FORCEINLINE __attribute__((always_inline)) inline +#endif + +template +struct alignas(1) BlockwiseQuantBlock { + static_assert(block_size % 8 == 0); + + uint8_t blob_data[block_size / 8 * bits]; + + FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const; + FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const; + + FORCEINLINE void quant(const T* src, T& scale, int32_t k_idx, int32_t K, int32_t N); + FORCEINLINE void quant(const T* src, T& scale, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N); +}; + +template +struct alignas(1) BlockwiseQuantBlock { + static_assert(block_size % 8 == 0); + + uint8_t blob_data[block_size / 2]; + + FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const { + for (int i = 0; i < block_size; i += 2) { + T zp_t = static_cast(float(zp)); + if (k_idx + i < K) { + T x0 = static_cast(float(blob_data[i / 2] & 0xF)); + dst[i] = scale * (x0 - zp_t); + } + if (k_idx + i + 1 < K) { + T x1 = static_cast(float(blob_data[i / 2] >> 4)); + dst[i + 1] = scale * (x1 - zp_t); + } + } + } + + FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const { + constexpr uint8_t zp = 8; + dequant(dst, scale, zp, k_idx, K); + } + + FORCEINLINE void quant(const T* src, T& scale_block, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N) { + float min = static_cast(*src); + float max = static_cast(*src); + int32_t klen = std::min(block_size, K - k_idx); + for (int32_t kk = 0; kk < klen; kk++) { + const float v = static_cast(src[N * kk]); + if (v < min) min = v; + if (v > max) max = v; + } + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + + const float scale = (max - min) / ((1 << 4) - 1); + scale_block = static_cast(scale); + + const float reciprocal_scale = scale ? 
1.0f / scale : 0.0f; + float zero_point_fp = min; + if (scale != 0.0f) { + zero_point_fp = 0.f - min / scale; + } + + // Handle any clamping + if (zero_point_fp < 0.0f) { + zp = 0; + } else if (zero_point_fp > 15.0f) { + zp = 15; + } else { + zp = (uint8_t)roundf(zero_point_fp); + } + + for (int32_t kk = 0; kk < klen; kk += 2) { + const float v0 = static_cast(src[N * kk]); + const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp))); + + const float v1 = static_cast((kk + 1 < klen) ? src[N * (kk + 1)] : 0.f); + const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp))); + + blob_data[kk / 2] = vi0 | (vi1 << 4); + } + } + + FORCEINLINE void quant(const T* src, T& scale_block, int32_t k_idx, int32_t K, int32_t N) { + float amax = 0.0f; // abs(max) + float max = 0.0f; + + int32_t klen = std::min(block_size, K - k_idx); + + for (int32_t kk = 0; kk < klen; kk++) { + const float v = static_cast(src[N * kk]); + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float scale = max / (-8.f); + scale_block = static_cast(scale); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + + for (int32_t kk = 0; kk < klen; kk += 2) { + const float v0 = src[N * kk] * reciprocal_scale; + const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 + 8.f))); + + const float v1 = (kk + 1 < klen) ? src[N * (kk + 1)] * reciprocal_scale : 0; + const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 + 8.f))); + + blob_data[kk / 2] = vi0 | (vi1 << 4); + } + } +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h new file mode 100644 index 0000000000000..8811e5649fc19 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "blockwise_quant_block.h" + +#include + +#include "core/common/safeint.h" +#include "core/framework/float16.h" +#include "core/platform/threadpool.h" +#include + +namespace onnxruntime { +namespace contrib { + +template +void QuantizeBlockwise( + uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ] + const T* src, // shape: [K, N] + T* scale, // shape: [N * block_per_K] + uint8_t* zero_points, // shape: [N * block_per_K] if bits > 4 else [(N *block_per_K + 1) / 2] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + BlockwiseQuantBlock* dst_blob = + reinterpret_cast*>(dst); + + int32_t block_per_K = (K + block_size - 1) / block_size; + int32_t total_block_count = N * block_per_K; + + std::vector zero_points_tmp; // to avoid race condition + (void)zero_points_tmp; + uint8_t* zero_points_tmp_ptr = zero_points; + if (bits <= 4 && zero_points != nullptr) { + zero_points_tmp.resize(total_block_count, 0); + zero_points_tmp_ptr = zero_points_tmp.data(); + } + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + int32_t n = static_cast(block_idx / block_per_K); + int32_t k_block_idx = static_cast(block_idx % block_per_K); + int32_t k = k_block_idx * block_size; + BlockwiseQuantBlock* blob_ptr = dst_blob + block_idx; + size_t offset = SafeInt(k) * N + n; + if (nullptr != zero_points_tmp_ptr) { + blob_ptr->quant(src + offset, scale[block_idx], zero_points_tmp_ptr[block_idx], k, K, N); + } else { + blob_ptr->quant(src + offset, scale[block_idx], k, K, N); + } + }, + 0); + + if (bits <= 4 && zero_points != nullptr) { // compact zero points + for (int32_t zp_idx = 0; zp_idx < total_block_count / 2; zp_idx++) { + zero_points[zp_idx] = ((zero_points_tmp[zp_idx * 2]) | (zero_points_tmp[zp_idx * 2 + 1] << 4)); + } + if (total_block_count & 1) { + zero_points[total_block_count / 2] = 
(zero_points[total_block_count / 2] & 0xf0) | zero_points_tmp[total_block_count - 1]; + } + } +} + +#define QuantizeBlockwise4Bits(block_size) \ + QuantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool); + +template +void QuantizeBlockwise( + uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ] + const T* src, // shape: [K, N] + T* scale, // shape: [N, block_per_K] + uint8_t* zero_points, // shape: [N, block_per_K] + int32_t block_size, + int32_t bits, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE(bits == 4, "only 4 bits is supported now"); + + if (16 == block_size) { + QuantizeBlockwise4Bits(16); + } else if (32 == block_size) { + QuantizeBlockwise4Bits(32); + } else if (64 == block_size) { + QuantizeBlockwise4Bits(64); + } else if (128 == block_size) { + QuantizeBlockwise4Bits(128); + } else if (256 == block_size) { + QuantizeBlockwise4Bits(256); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef QuantizeBlockwise4Bits + +template +void DequantizeBlockwise( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [N, block_per_K, block_blob_size] + const T* scale, // shape: [N, block_per_K] + const uint8_t* zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t block_per_K = (K + block_size - 1) / block_size; + int32_t task_count = N * block_per_K; + + const BlockwiseQuantBlock* src_blob = + reinterpret_cast*>(src); + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + task_count, + [&](ptrdiff_t task_idx) { + int32_t n = static_cast(task_idx / block_per_K); + int32_t k_block_idx = static_cast(task_idx % block_per_K); + int32_t k = k_block_idx * block_size; + const BlockwiseQuantBlock* blob_ptr = src_blob + task_idx; + size_t offset = SafeInt(n) * K + k; + if (nullptr != zero_points) { + if constexpr (bits > 4) 
{ // zero point is stored with a byte + blob_ptr->dequant(dst + offset, scale[task_idx], zero_points[task_idx], k, K); + } else { // zero points is stored with 4bits + uint8_t zp = zero_points[task_idx / 2]; + zp = (task_idx & 1) ? (zp >> 4) : (zp & 0xf); + blob_ptr->dequant(dst + offset, scale[task_idx], zp, k, K); + } + } else { + blob_ptr->dequant(dst + offset, scale[task_idx], k, K); + } + }, + 0); +} + +#define DequantizeBlockwise4Bits(block_size) \ + DequantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool); + +template +void DequantizeBlockwise( + T* dst, // [N, K] + const uint8_t* src, // [N, block_per_K, block_blob_size] + const T* scale, // [N, block_per_K] + const uint8_t* zero_points, // [N, block_per_K] + int32_t block_size, + int32_t bits, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE(bits == 4, "only 4 bits is supported now"); + + if (16 == block_size) { + DequantizeBlockwise4Bits(16); + } else if (32 == block_size) { + DequantizeBlockwise4Bits(32); + } else if (64 == block_size) { + DequantizeBlockwise4Bits(64); + } else if (128 == block_size) { + DequantizeBlockwise4Bits(128); + } else if (256 == block_size) { + DequantizeBlockwise4Bits(256); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef DequantizeBlockwise4Bits + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..57aada94be39c --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" +#include "dequantize_blockwise.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulNBits final : public OpKernel { + public: + MatMulNBits(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; +}; + +Status MatMulNBits::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + const auto* a_data = a->Data(); + const uint8_t* b_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? 
nullptr : zero_points->Data(); + + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + DequantizeBlockwise(tmp_b_data_ptr.get(), + b_data, + scales_data, + zero_points_data, + static_cast(block_size_), + static_cast(nbits_), + static_cast(N_), + static_cast(K_), + thread_pool); + +#if 0 // for debug + auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); +#endif + + TensorShape b_shape({N_, K_}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) + return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + const size_t ldb = helper.Ldb(true); + + // TODO: implement with native kernel + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), max_len, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // 
namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 4bb3a0a2cacb9..52ff285539360 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -114,6 +114,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping); @@ -271,6 +273,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu new file mode 100644 index 0000000000000..8c328d00b44d0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -0,0 +1,131 @@ +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "dequantize_blockwise.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { + half2 scale_half2 = {scale, scale}; + half zp_adjust = -scale * __short2half_rn(zp); + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + + alignas(16) half2 results[4]; + half v0 = __uint2half_rn(values_quant & 0xF); + half v1 = __uint2half_rn((values_quant >> 4) & 0xF); + results[0] = __halves2half2(v0, v1) * scale_half2 + zp_adjust2; + + half v2 = __uint2half_rn((values_quant >> 8) & 0xF); + half v3 = __uint2half_rn((values_quant >> 12) & 0xF); + results[1] = __halves2half2(v2, v3) * scale_half2 + zp_adjust2; + + half v4 = __uint2half_rn((values_quant >> 16) & 0xF); + half v5 = __uint2half_rn((values_quant >> 20) & 0xF); + results[2] = __halves2half2(v4, v5) * scale_half2 + zp_adjust2; + + half v6 = __uint2half_rn((values_quant >> 24) & 0xF); + half v7 = __uint2half_rn((values_quant >> 28) & 0xF); + results[3] = __halves2half2(v6, v7) * scale_half2 + zp_adjust2; + *(reinterpret_cast(output)) = *(reinterpret_cast(results)); +} + +__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, float scale, float zp, float* output) { + float zp_adjust = -scale * zp; + output[0] = float(values_quant & 0xF) * scale + zp_adjust; + output[1] = float((values_quant >> 4) & 0xF) * scale + zp_adjust; + output[2] = float((values_quant >> 8) & 0xF) * scale + zp_adjust; + output[3] = float((values_quant >> 12) & 0xF) * scale + zp_adjust; + output[4] = float((values_quant >> 16) & 0xF) * scale + zp_adjust; + output[5] = float((values_quant >> 20) & 0xF) * scale + zp_adjust; + output[6] = 
float((values_quant >> 24) & 0xF) * scale + zp_adjust; + output[7] = float((values_quant >> 28) & 0xF) * scale + zp_adjust; +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const uint8_t* zero_points, + int block_size, + int blocks_per_threadblock, + int shift) { + int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); + int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + T scale = *(scale_data + block_id); + uint8_t zp = 8; + if (zero_points) { + zp = (block_id & 0x01) ? (zero_points[block_id / 2] >> 4) : (zero_points[block_id / 2] & 0x0f); + } + + output = output + element_offset; + DequantizeEightElements(quant_value, scale, static_cast(zp), output); +} + +template +Status Dequantize4Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + int k, + int n, + int block_size, + cudaStream_t stream) { + // k is padded and equal to block_per_K * block_size + ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); + constexpr int element_per_thread = 8; + int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int blocks_per_K = k / block_size; + int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); + int shift = static_cast(log2f(float(block_size))); + + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + blocks_per_threadblock, + shift); + + return Status::OK(); +} + +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + 
const half* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh new file mode 100644 index 0000000000000..741ce1e735b42 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +template +Status Dequantize4Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..14a8163fef500 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// +// This module defines the MatMulNBits operator, it is basically +// matmul float32/float16 with right hand side being a 2-D matrix +// pre-packed and block-wise quantized into int4 +// + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "matmul_nbits.cuh" +#include "dequantize_blockwise.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; +}; + +template +Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + const auto* a_data = a->Data(); + const uint8_t* blob_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? 
nullptr : zero_points->Data(); + + ORT_ENFORCE(nbits_ == 4, "only 4 bits is supported now"); + + typedef typename ToCudaType::MappedType CudaT; + + constexpr bool transa = false; + constexpr bool transb = true; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + if (!is_4bit_done) { + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + ORT_RETURN_IF_ERROR(Dequantize4Bits(reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); +#if 0 + cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); + T* b_data_cpu = new T[K_ * N_]; + cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); + delete[] b_data_cpu; +#endif + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + 
&zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu new file mode 100644 index 0000000000000..4c3c345076416 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -0,0 +1,217 @@ +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "matmul_nbits.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) { + half2 scale_half2 = {scale, scale}; + half zp_adjust = -scale * __short2half_rn(zp); + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + uint4 vec_a = *(reinterpret_cast(a)); + + half2 element01 = __halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF)); + half2 v0 = element01 * scale_half2 + zp_adjust2; + + half2 element23 = __halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF)); + half2 v1 = element23 * scale_half2 + zp_adjust2; + + half2 element45 = __halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF)); + half2 v2 = element45 * scale_half2 + zp_adjust2; + + half2 element67 = __halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF)); + half2 v3 = element67 * scale_half2 + zp_adjust2; + + v0 = v0 * (*(reinterpret_cast(&(vec_a.x)))); + v1 = v1 * (*(reinterpret_cast(&(vec_a.y)))); + v2 = v2 * (*(reinterpret_cast(&(vec_a.z)))) + v0; + v3 = v3 * (*(reinterpret_cast(&(vec_a.w)))) + v1; + v3 = v2 + v3; + return float(v3.x) + float(v3.y); +} + +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) { + float4 a_vec_0 = *(reinterpret_cast(a)); + float4 a_vec_1 = *(reinterpret_cast(a + 4)); + + float zp_adjust = -scale * zp; + float v0 = float(values_quant & 0xF) * scale + zp_adjust; + float v1 = float((values_quant >> 4) & 0xF) * scale + zp_adjust; + float v2 = float((values_quant >> 8) & 0xF) * scale + zp_adjust; + float v3 = 
float((values_quant >> 12) & 0xF) * scale + zp_adjust; + float v4 = float((values_quant >> 16) & 0xF) * scale + zp_adjust; + float v5 = float((values_quant >> 20) & 0xF) * scale + zp_adjust; + float v6 = float((values_quant >> 24) & 0xF) * scale + zp_adjust; + float v7 = float((values_quant >> 28) & 0xF) * scale + zp_adjust; + + v0 = v0 * a_vec_0.x; + v1 = v1 * a_vec_0.y; + v2 = v2 * a_vec_0.z; + v3 = v3 * a_vec_0.w; + v4 = v4 * a_vec_1.x + v0; + v5 = v5 * a_vec_1.y + v1; + v6 = v6 * a_vec_1.z + v2; + v7 = v7 * a_vec_1.w + v3; + return v4 + v5 + v6 + v7; +} + +constexpr int kColsPerThreadBlock = 8; +constexpr int kWarpSize = 32; + +// kernel for 4bits quantized gemv, i.e., computing A(1,K) x B(K, N) +// B(K, N) is quantized blockwise with 4bits and stored as [N, (K + block_size - 1)/block_size, blob] +// The thread block size is (kWarpSize, kColsPerThreadBlock) and grid size is (N/kColsPerThreadBlock, 1) +// Each thread block computes [1, K] x [kColsPerThreadBlock, (K + block_size - 1)/block_size, blob], +// i.e., computing kColsPerThreadBlock per block and a warp reduce (1, K) x (K) +template +__global__ void MatMulFloatInt4Kernel( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int blocks_per_K) { + int n_block_id = blockIdx.x; + int m_id = blockIdx.y; + int lane_id = threadIdx.x; + int warp_id = threadIdx.y; + int n_id = n_block_id * kColsPerThreadBlock + warp_id; + int thread_id = warp_id * kWarpSize + lane_id; + constexpr int k_per_iter = 256; + int k_iter = k / k_per_iter; + + extern __shared__ char shared_buffer[]; + + // load scale to shared buffer + T* b_scale_vec = (T*)shared_buffer; + uint8_t* b_zp_vec = reinterpret_cast(b_scale_vec + kColsPerThreadBlock * blocks_per_K); + int offset = n_block_id * kColsPerThreadBlock * blocks_per_K; + for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { + b_scale_vec[i] = 
scales_data[offset + i]; + } + for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K / 2; i += kColsPerThreadBlock * kWarpSize) { + b_zp_vec[i] = zero_points != nullptr ? zero_points[offset / 2 + i] : uint8_t(0x88); + } + __syncthreads(); + + a_data += m_id * k; + b_data_quant += n_id * blocks_per_K * (block_size / 2); + + float sum = 0.f; + int k_id = 0; + for (; k_id < (k & 0xffffff00); k_id += k_per_iter) { + uint32_t value = *(reinterpret_cast(b_data_quant + (k_id >> 1) + lane_id * 4)); + int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size; + T scale = b_scale_vec[block_idx]; + uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f); + sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); + } + + // handle reminder + if (k_id + lane_id * 8 < k) { + uint32_t value = *(reinterpret_cast(b_data_quant + k_iter * 128 + lane_id * 4)); + int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size; + T scale = b_scale_vec[block_idx]; + uint8_t zp = (block_idx & 0x01) ? 
(b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f); + sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); + } + + // warp reduction + for (int i = 16; i > 0; i = i / 2) { + sum += __shfl_down_sync(0xffffffff, sum, i); + } + + if (lane_id == 0) { + output[m_id * n + n_id] = sum; + } +} + +template +bool TryMatMul4Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream) { + if (n % kColsPerThreadBlock != 0 || k % 8 != 0 || m > 1) { + return false; + } + dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); + dim3 threads(kWarpSize, kColsPerThreadBlock); + int blocks_per_K = (k + block_size - 1) / block_size; + int blocks_per_thread_block = blocks_per_K * kColsPerThreadBlock; + int shared_mem_size = sizeof(T) * blocks_per_thread_block + blocks_per_thread_block / 2; + if (shared_mem_size > shared_mem_per_block) { + return false; + } + + if (16 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (32 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (64 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (128 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else { + ORT_THROW("block size ", block_size, " is not supported"); + } + + return true; +} + +template bool TryMatMul4Bits( + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + 
+template bool TryMatMul4Bits( + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh new file mode 100644 index 0000000000000..9ccbe4c4d97a8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMul4Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index c20498dfb690e..5e5eee568fa21 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2573,6 +2573,35 @@ ONNX_MS_OPERATOR_SET_SCHEMA(CropAndResize, 1, a fixed size = [crop_height, crop_width]. The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth]. 
The resizing is corner aligned.)DOC")); +static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, + int64_t K, + int64_t N) { + int input_a_idx = 0; + if (!hasInputShape(ctx, input_a_idx)) { + return; + } + + const auto& a_shape = ctx.getInputType(input_a_idx)->tensor_type().shape(); + if (a_shape.dim_size() == 0) { + fail_shape_inference("Input tensors of wrong rank (0)."); + } + + // TODO: check B shape + + const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1); + if (dim_last.has_dim_value() && dim_last.dim_value() != K) { + fail_shape_inference("Incompatible dimensions for matrix multiplication"); + } + + ONNX_NAMESPACE::TensorShapeProto resultShape; + for (int i = 0; i < a_shape.dim_size() - 1; ++i) { + *resultShape.add_dim() = a_shape.dim(i); + } + resultShape.add_dim()->set_dim_value(N); + + *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape; +} + void RegisterContribSchemas() { ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema); ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema); @@ -3161,6 +3190,55 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t } }); + static const char* MatMulNBits_ver1_doc = R"DOC( +MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's scale and zero point are specified by input scales and zero_points. 
+ +Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: +- n_blocks_per_col = (K + block_size - 1) / block_size +- blob_size = block_size / 8 * bits + + For a block blob. It is stored in format: + struct Blob { + uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization + uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization + uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization + } + +Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] +Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one uint8_t. If bits > 4, one zero point is stored with one uint8_t. Thus, its shape is: + - [(N * n_blocks_per_col + 1) / 2] if bits <= 4 + - [N * n_blocks_per_col] if bits > 4 + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulNBits_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional data blob", "T2") + .Input(2, "scales", "quantization scale", "T1") + .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. 
", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + }); + #ifdef ENABLE_ATEN ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) .SetDomain(kPytorchAtenDomain) diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index f0ab7869b7d50..c0b282b202ef6 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#pragma once namespace onnxruntime { extern ProviderHost* g_host; diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index 6824a5d0bf98f..1d8ca195ab82b 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -17,6 +17,7 @@ static constexpr bool HAS_COLLECTIVE_OPS = false; #endif void CreateInferencePybindStateModule(py::module& m); +void CreateQuantPybindModule(py::module& m); PYBIND11_MODULE(onnxruntime_pybind11_state, m) { CreateInferencePybindStateModule(m); @@ -30,6 +31,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { m.def("get_version_string", []() -> std::string { return ORT_VERSION; }); m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; }); m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; }); + CreateQuantPybindModule(m); } } // namespace python } // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc new file mode 100644 index 0000000000000..52ea677d5141d --- /dev/null +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include + +#include "contrib_ops/cpu/quantization/dequantize_blockwise.h" +#include "core/util/thread_utils.h" + +namespace pybind11 { +namespace detail { +// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' +constexpr int NPY_FLOAT16 = 23; +template <> +struct npy_format_descriptor { + static constexpr auto name = _("float16"); + static pybind11::dtype dtype() { + handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); + return reinterpret_borrow(ptr); + } + static std::string format() { + // following: https://docs.python.org/3/library/struct.html#format-characters + return "e"; + } +}; +} // namespace detail +} // namespace pybind11 + +namespace onnxruntime { +namespace python { + +namespace py = pybind11; +using namespace onnxruntime; + +template +void QuantizeMatMul4BitsBlockwise( + py::array_t dst, // shape: [ N, block_per_K, block_blob_size ] + py::array_t src, // shape: [K, N] + py::array_t scale, // shape: [N, block_per_K] + py::array_t zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2] + int32_t block_size, + int32_t N, + int32_t K, + bool is_symmetric) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + py::buffer_info dst_buf = dst.request(); + py::buffer_info src_buf = src.request(); + py::buffer_info scale_buf = scale.request(); + py::buffer_info zp_buf = zero_points.request(); + + contrib::QuantizeBlockwise( + static_cast(dst_buf.ptr), + static_cast(src_buf.ptr), + static_cast(scale_buf.ptr), + is_symmetric ? 
nullptr : static_cast(zp_buf.ptr), + block_size, + 4, + N, + K, + tp.get()); +} + +void CreateQuantPybindModule(py::module& m) { + m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); + m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); +} + +} // namespace python +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/device_array.h b/onnxruntime/python/tools/kernel_explorer/device_array.h index bb868c2b7a59a..12c526fa0c813 100644 --- a/onnxruntime/python/tools/kernel_explorer/device_array.h +++ b/onnxruntime/python/tools/kernel_explorer/device_array.h @@ -62,8 +62,8 @@ class DeviceArray { private: std::shared_ptr device_; void* host_; - ssize_t size_; - ssize_t itemsize_; + py::ssize_t size_; + py::ssize_t itemsize_; }; } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu new file mode 100644 index 0000000000000..9b5e4079a7e31 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. 
+ +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct DequantizeInt4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + T* output_; + const uint8_t* quant_; + const T* scales_; + const uint8_t* zero_points_; + int n_; + int k_; +}; + +template +class DequantizeInt4 : public IKernelExplorer { + public: + DequantizeInt4(DeviceArray& output, DeviceArray& quant, DeviceArray& scales, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.quant_ = static_cast(quant.ptr()); + params_.scales_ = static_cast(scales.ptr()); + params_.zero_points_ = nullptr; + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::Dequantize4Bits( + params_.output_, + params_.quant_, + params_.scales_, + params_.zero_points_, + params_.k_, + params_.n_, + 32, + params_.StreamHandle())); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = DequantizeInt4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(DequantizeInt4, half); + REGISTER_OP(DequantizeInt4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu new file 
mode 100644 index 0000000000000..fd9e9c4fd1612 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include + +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/cuda_stream_handle.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct GemmBenchmarkParams : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + T* output_; + const T* a_; + const T* b_; + int m_; + int n_; + int k_; + cublasHandle_t cublas_handle; +}; + +template +class GemmBenchmark : public IKernelExplorer { + public: + GemmBenchmark(DeviceArray& output, DeviceArray& a, DeviceArray& b, int m, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.a_ = static_cast(a.ptr()); + params_.b_ = static_cast(b.ptr()); + params_.m_ = m; + params_.n_ = n; + params_.k_ = k; + + CUBLAS_CALL_THROW(cublasCreate(&(params_.cublas_handle))); + CUDA_CALL_THROW(cudaGetDeviceProperties(&device_prop_, 0)); + } + + void Run() override { + typedef typename ToCudaType::MappedType CudaT; + CudaT one = ToCudaType::FromFloat(1.0f); + CudaT zero = ToCudaType::FromFloat(0.0f); + CUBLAS_CALL_THROW(cublasGemmHelper( + params_.cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + params_.n_, params_.m_, params_.k_, + &one, + 
reinterpret_cast(params_.b_), + params_.n_, + reinterpret_cast(params_.a_), + params_.k_, + &zero, + params_.output_, + params_.n_, + device_prop_)); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = GemmBenchmarkParams; + ParamsT params_{}; + cudaDeviceProp device_prop_; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(GemmBenchmark, half); + REGISTER_OP(GemmBenchmark, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu new file mode 100644 index 0000000000000..9e8c4cd7be36e --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. 
+ +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct MatrixFloatInt4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + T* output_; + const T* a_; + const uint8_t* b_; + const T* scales_; + const uint8_t* zero_points_; + int m_; + int n_; + int k_; +}; + +template +class MatrixFloatInt4 : public IKernelExplorer { + public: + MatrixFloatInt4(DeviceArray& output, + DeviceArray& a, + DeviceArray& b, + DeviceArray& scales, + int m, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.a_ = static_cast(a.ptr()); + params_.b_ = static_cast(b.ptr()); + params_.scales_ = static_cast(scales.ptr()); + params_.zero_points_ = nullptr; + params_.m_ = m; + params_.n_ = n; + params_.k_ = k; + + CUDA_CALL_THROW(cudaGetDeviceProperties(&device_prop_, 0)); + } + + MatrixFloatInt4(DeviceArray& output, + DeviceArray& a, + DeviceArray& b, + DeviceArray& scales, + DeviceArray& zeropoints, + int m, int n, int k) : MatrixFloatInt4(output, a, b, scales, m, n, k) { + params_.zero_points_ = static_cast(zeropoints.ptr()); + } + + void Run() override { + contrib::cuda::TryMatMul4Bits( + params_.output_, + params_.a_, + params_.b_, + params_.scales_, + params_.zero_points_, + params_.m_, + params_.n_, + params_.k_, + 32, + static_cast(device_prop_.sharedMemPerBlock), + params_.StreamHandle()); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = MatrixFloatInt4Params; + ParamsT params_{}; 
def profile_dequantize_int4_func(n, k, dtype, func):
    """Profile one DequantizeInt4 kernel binding and report its bandwidth.

    Args:
        n: Number of rows of the dequantized output.
        k: Number of columns of the dequantized output.
        dtype: NumPy dtype name of the output and scales ("float16" or "float32").
        func: Name of the kernel class to look up on the ``kernel_explorer`` module.
    """
    np.random.seed(0)
    # Each row is split into ceil(k / 32) blocks; a block is stored as 16 packed bytes.
    k_blocks = (k + 31) // 32
    output = np.random.rand(n, k).astype(dtype)
    quant = np.random.randint(low=0, high=127, size=(n, k_blocks, 16)).astype("uint8")
    scales = np.random.rand(n, k_blocks).astype(dtype)

    output_d = ke.DeviceArray(output)
    quant_d = ke.DeviceArray(quant)
    scales_d = ke.DeviceArray(scales)
    kernel = getattr(ke, func)(output_d, quant_d, scales_d, n, k)
    duration_ms = kernel.Profile()

    # Count the bytes of the buffers actually handed to the kernel. This equals the
    # previous closed-form expression when k is a multiple of 32 and stays consistent
    # with the ceil-sized buffers above when it is not.
    total_bytes = quant.nbytes + output.nbytes + scales.nbytes

    ke.report(DequantizeInt4Metric(func, dtype, duration_ms, total_bytes, n, k))
def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric):
    """Profile one MatrixFloatInt4 kernel binding and report its bandwidth.

    Args:
        m: Rows of the activation matrix ``a`` and of the output.
        n: Columns of the output (one packed weight row per output column).
        k: Reduction dimension.
        dtype: NumPy dtype name of activations, scales and output.
        func: Name of the kernel class to look up on the ``kernel_explorer`` module.
        is_symmetric: When true the kernel is built without a zero-point buffer.
    """
    np.random.seed(0)
    # The weight is stored as ceil(k / 32) blocks per output column, 16 packed bytes each.
    k_blocks = (k + 31) // 32
    output = np.random.rand(m, n).astype(dtype)
    a = np.random.rand(m, k).astype(dtype)
    b = np.random.randint(low=0, high=127, size=(n, k_blocks, 16)).astype("uint8")
    scales = np.random.rand(n * k_blocks).astype(dtype)
    # One 4-bit zero point per block, two per byte -- the same layout the quantizer emits.
    # NOTE(review): assumes the kernel reads this buffer as packed uint8; confirm against
    # the MatrixFloatInt4 binding in matmul_4bits.cu.
    zeropoints = np.random.randint(low=0, high=256, size=(n * k_blocks + 1) // 2).astype("uint8")

    output_d = ke.DeviceArray(output)
    a_d = ke.DeviceArray(a)
    b_d = ke.DeviceArray(b)
    scales_d = ke.DeviceArray(scales)
    zeropoints_d = ke.DeviceArray(zeropoints)
    f = getattr(ke, func)

    my_op = (
        f(output_d, a_d, b_d, scales_d, m, n, k)
        if is_symmetric
        else f(output_d, a_d, b_d, scales_d, zeropoints_d, m, n, k)
    )
    duration_ms = my_op.Profile()

    # Bandwidth is based on the bytes of the buffers actually passed to the kernel:
    # B is packed 4-bit data, not an (n, k) matrix of `dtype`.
    total_bytes = a.nbytes + b.nbytes + scales.nbytes + output.nbytes
    if not is_symmetric:
        total_bytes += zeropoints.nbytes

    ke.report(MatrixFpInt4Metric(func, dtype, duration_ms, total_bytes, m, n, k, is_symmetric))
def __init__(self, model: ModelProto, block_size: int, is_symmetric: bool, nodes_to_exclude=None):
    """Set up a quantizer for the constant MatMul weights of ``model``.

    Args:
        model: The ONNX model to rewrite; it is wrapped in ``ONNXModel``.
        block_size: Number of weight elements that share one scale (and zero point).
            Values convertible with ``int()`` (e.g. a string from the command line)
            are accepted.
        is_symmetric: Whether to quantize without per-block zero points.
        nodes_to_exclude: Optional iterable of node names that must not be quantized.

    Raises:
        ValueError: If ``block_size`` is not a power of two or is smaller than 16,
            which the MatMulNBits operator does not support.
    """
    block_size = int(block_size)
    # A power of two has exactly one bit set, so `x & (x - 1)` is zero.
    if block_size < 16 or block_size & (block_size - 1):
        raise ValueError(f"block_size must be a power of 2 and not smaller than 16, got {block_size}")

    self.model = ONNXModel(model)
    self.block_size = block_size
    self.is_symmetric = is_symmetric
    self.nodes_to_exclude = set(nodes_to_exclude or ())
def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto:
    """Return a MatMulNBits replacement for ``node`` when it is a MatMul with a constant 2-D weight.

    The node is returned unchanged when it is not a MatMul, is listed in
    ``nodes_to_exclude``, has no initializer as its second input, or that
    initializer is not 2-D.

    Args:
        node: The node to inspect.
        graph_stack: Enclosing graphs, outermost first; the weight initializer is
            looked up from the innermost graph outwards.

    Returns:
        The original node, or a new ``com.microsoft.MatMulNBits`` node that reads
        the packed weight, the scales and (when asymmetric) the zero points.
    """
    if node.op_type != "MatMul":
        return node  # only care about MatMul for now

    logger.info(f"start to quantize {node.name} ...")
    if node.name in self.nodes_to_exclude:
        logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
        return node

    weight_name = node.input[1]
    weight, weight_graph = MatMul4BitsQuantizer.__get_initializer(weight_name, graph_stack)
    if weight is None:
        logger.info("MatMul doesn't have const weight. Skip to quantize")
        return node  # only care about constant weight

    weight_array = onnx.numpy_helper.to_array(weight)
    if len(weight_array.shape) != 2:
        logger.info("MatMul weight is not 2D. Skip to quantize")
        return node  # can only process 2-D matrix

    quant_name = weight.name + "_Q4"
    scales_name = weight.name + "_scales"
    zero_points_name = weight.name + "_zero_points"

    # A weight may feed several MatMul nodes. Quantize it only once: appending the
    # same initializer names again would leave duplicate tensors in the graph.
    already_quantized = any(tensor.name == quant_name for tensor in weight_graph.initializer)
    if not already_quantized:
        packed, scales, zero_points = self.int4_block_quant(weight_array)
        new_tensors = [
            onnx.numpy_helper.from_array(packed, name=quant_name),
            onnx.numpy_helper.from_array(scales, name=scales_name),
        ]
        if not self.is_symmetric:
            new_tensors.append(onnx.numpy_helper.from_array(zero_points, name=zero_points_name))
        weight_graph.initializer.extend(new_tensors)

    # The float weight is no longer consumed by this node, so drop it from the
    # graph inputs if it was also declared there.
    for graph_input in weight_graph.input:
        if graph_input.name == weight_name:
            weight_graph.input.remove(graph_input)
            break

    input_names = [node.input[0], quant_name, scales_name]
    if not self.is_symmetric:
        input_names.append(zero_points_name)

    rows, cols = weight_array.shape
    matmul_q4_node = onnx.helper.make_node(
        "MatMulNBits",
        inputs=input_names,
        outputs=[node.output[0]],
        name=node.name + "_Q4" if node.name else "",
        domain="com.microsoft",
        K=rows,
        N=cols,
        bits=4,
        block_size=self.block_size,
    )

    logger.info(f"complete quantization of {node.name} ...")

    return matmul_q4_node
def parse_args():
    """Parse the command-line options of the 4-bit MatMul quantization tool.

    Returns:
        argparse.Namespace with ``input_model``, ``output_model``, ``block_size`` (int),
        ``symmetric`` (bool), ``verbose`` (bool) and ``nodes_to_exclude`` (list of str).
    """

    def _str_to_bool(value):
        """Convert a command-line token such as "true"/"False"/"0" to a bool."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="""Blockwise int4 quantization for MatMul 2D weight matrices.

A weight matrix is partitioned into blocks, where each block is a
contiguous subset inside each column. Each block is quantized into a
set of 4b integers with a scaling factor and an optional offset.
"""
    )

    parser.add_argument("--input_model", required=True, help="Path to the input model file")
    parser.add_argument("--output_model", required=True, help="Path to the output model file")
    # Without type=int a value given on the command line would arrive as a string.
    parser.add_argument(
        "--block_size", required=False, default=32, type=int, help="Number of weight elements per quantization block"
    )
    # Without a converter every supplied string, including "False", is truthy.
    # nargs="?" with const=True also lets a bare `--symmetric` mean True.
    parser.add_argument(
        "--symmetric",
        required=False,
        default=True,
        const=True,
        nargs="?",
        type=_str_to_bool,
        help="Indicate whether to quantize the model symmetrically",
    )
    parser.add_argument("-v", "--verbose", required=False, action="store_true")
    parser.add_argument(
        "--nodes_to_exclude",
        nargs="+",
        type=str,
        required=False,
        default=[],
        help="Specify the nodes to be excluded from quantization with node names",
    )

    return parser.parse_args()
+ +#ifndef ORT_MINIMAL_BUILD + +#include "core/common/span_utils.h" +#include "core/framework/tensor.h" +#include "core/mlas/inc/mlas_q4.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/inference_session.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" +#include "core/util/qmath.h" +#include "contrib_ops/cpu/quantization/dequantize_blockwise.h" + +#include +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +namespace onnxruntime { +namespace test { + +void QuantizeDequantize(std::vector& raw_vals, + std::vector& quant_vals, + std::vector& scales, + std::vector* zp, + int32_t N, + int32_t K, + int32_t block_size) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + contrib::QuantizeBlockwise( + quant_vals.data(), + raw_vals.data(), + scales.data(), + zp != nullptr ? zp->data() : nullptr, + block_size, + 4, + N, + K, + tp.get()); + + // Note that input1_f_vals is NxK after dequant + contrib::DequantizeBlockwise( + raw_vals.data(), + quant_vals.data(), + scales.data(), + zp != nullptr ? 
zp->data() : nullptr, + block_size, + 4, + N, + K, + tp.get()); +} + +void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zeropoint, bool use_float16) { + RandomValueGenerator random{1234}; + std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); + std::vector input1_f_vals(random.Gaussian(std::vector({K, N}), 0.0f, 0.25f)); + +#if 0 // for Debugging + std::vector input1_f_vals_trans(N * K); + MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N); +#endif + + int64_t block_per_k = (K + block_size - 1) / block_size; + int64_t number_of_block = block_per_k * N; + int64_t block_blob_size = block_size * 4 / 8; + int64_t buf_size = number_of_block * (block_size * 4 / 8); + std::vector input1_vals(buf_size); + std::vector scales(number_of_block); + std::vector zp((N * block_per_k + 1) / 2); + + QuantizeDequantize(input1_f_vals, + input1_vals, + scales, + has_zeropoint ? &zp : nullptr, + static_cast(N), + static_cast(K), + static_cast(block_size)); + + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_vals[m * K + k] * input1_f_vals[n * K + k]; + } + expected_vals[m * N + n] = sum; + } + } + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", 4); + if (use_float16) { + test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); + test.AddInput("B", {N, block_per_k, block_blob_size}, input1_vals, true); + test.AddInput("scales", {N * block_per_k}, ToFloat16(scales), true); + if (has_zeropoint) { + test.AddInput("zero_points", {(N * block_per_k + 1) / 2}, zp, true); + } + + test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); + test.SetOutputAbsErr("Y", 0.02f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } else { + test.AddInput("A", {M, K}, input0_vals, false); + test.AddInput("B", {N, block_per_k, block_blob_size}, input1_vals, true); + test.AddInput("scales", {N * block_per_k}, scales, true); + if (has_zeropoint) { + test.AddInput("zero_points", {(N * block_per_k + 1) / 2}, zp, true); + } + + test.AddOutput("Y", {M, N}, expected_vals); + + test.Run(); + } +} + +TEST(MatMulNBits, Float32) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(M, N, K, block_size, false, false); + RunTest(M, N, K, block_size, true, false); + } + } + } + } +} + +#if defined(USE_CUDA) +TEST(MatMulNBits, Float16) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(M, N, K, block_size, false, true); + RunTest(M, N, K, block_size, true, true); + } + } + } + } +} + +#endif +} // namespace test +} // namespace onnxruntime + +#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc b/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc index dd886ed1c6f5b..09ae5eddb122c 100644 --- a/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc @@ -24,7 +24,7 @@ namespace onnxruntime { namespace test { TEST(MatMulFpQ4, MatMul2DSym) { - // (100 x 41) X (41 x 288) + // (100 x 52) X (52 x 288) constexpr int64_t M = 100; constexpr int64_t N = 288; constexpr int64_t K = 52; diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py new file mode 100644 index 0000000000000..02f51cc4fa809 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python 
def fill_int4_data(self, shape: Union[int, Tuple[int, ...]], symmetric: bool) -> np.ndarray:
    """Build a float64 array of ``shape`` filled with a repeating int4-friendly pattern.

    With ``symmetric`` the flattened values cycle through -8..7 with -3, 0 and 3
    left out, starting at -2. Otherwise they cycle through 0..15 starting at 0.
    """
    count = np.zeros(shape).size
    positions = np.arange(count)

    if symmetric:
        # The 13 levels the generator ever emits, in emission order.
        levels = np.array([v for v in range(-8, 8) if v not in (-3, 0, 3)], dtype=np.float64)
        first = int(np.flatnonzero(levels == -2)[0])
        flat = levels[(positions + first) % levels.size]
    else:
        flat = (positions % 16).astype(np.float64)

    return flat.reshape(shape)
def quant_test(
    self,
    model_fp32_path: str,
    data_reader: TestDataFeeds,
    block_size: int,
    is_symmetric: bool,
):
    """Quantize the fp32 model to 4 bits and compare it against the original.

    Args:
        model_fp32_path: Path of the float model to quantize.
        data_reader: Source of the input feeds used for the correctness check.
        block_size: Quantization block size passed to ``MatMul4BitsQuantizer``.
        is_symmetric: Whether to quantize without zero points.
    """
    model_int4_path = str(
        Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute()
    )

    # Quantize fp32 model to int4 model
    from onnxruntime.quantization import matmul_4bits_quantizer

    model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
    quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric)
    quant.process()
    quant.model.save_model_to_file(model_int4_path, False)

    quant_nodes = {"MatMulNBits": 1}
    check_op_type_count(self, model_int4_path, **quant_nodes)

    data_reader.rewind()

    try:
        check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next())
    except Exception as exception:
        # Currently we don't have int4 quantization support on all platforms, so this
        # one failure has to be tolerated. str() is used instead of args[0] so that an
        # exception without arguments, or with a non-string first argument, cannot make
        # the handler itself fail and hide the original error.
        if "4b quantization not yet supported on this hardware platform!" not in str(exception):
            raise
def dequantize_blockwise_4bits(quant_values, scale, zero_point, valid_len):
    """Dequantize one packed 4-bit block.

    Each byte of ``quant_values`` holds two values: the low nibble is element
    ``2 * b`` and the high nibble is element ``2 * b + 1``.

    Args:
        quant_values: 1-D uint8 array with the packed block.
        scale: NumPy floating scalar; its dtype is the dtype of the result.
        zero_point: Integer zero point of the block (0..15).
        valid_len: Number of leading elements that are real data; the rest are
            returned as 0.

    Returns:
        1-D array with ``2 * len(quant_values)`` entries equal to
        ``(nibble - zero_point) * scale``.
    """
    blob_size = quant_values.shape[0]
    block_size = blob_size * 2

    packed = np.asarray(quant_values, dtype=np.uint8)
    # Unpack into a signed type before subtracting: uint8 - uint8 wraps around
    # (e.g. 3 - 8 becomes 251) for every nibble below the zero point.
    levels = np.empty(block_size, dtype=np.int32)
    levels[0::2] = packed & 0x0F
    levels[1::2] = packed >> 4
    levels -= int(zero_point)

    quant_float = (levels.astype(scale.dtype) * scale).astype(scale.dtype, copy=False)
    quant_float[max(valid_len, 0) :] = 0.0
    return quant_float
class TestQuantizeBlockwise4Bits(unittest.TestCase):
    """Compare the native blockwise 4-bit quantizer with the NumPy reference."""

    @staticmethod
    def _unpack_zero_point(zero_points, block_idx):
        """Return the 4-bit zero point of ``block_idx``; two are packed per byte, odd index in the high nibble."""
        packed = zero_points[block_idx // 2]
        return (packed >> 4) if (block_idx & 1) else (packed & 0x0F)

    @unittest.skipIf(
        find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
    )
    def test_quantize_blockwise_4bits(self):
        from itertools import product

        # Fixed seed so that a failure can be reproduced.
        np.random.seed(0)

        shapes = [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]
        block_sizes = [16, 32, 64, 128]
        dtypes = [np.float32, np.float16]

        for (rows, cols), block_size, dtype, is_symmetric in product(shapes, block_sizes, dtypes, [True, False]):
            with self.subTest(
                rows=rows, cols=cols, block_size=block_size, dtype=dtype.__name__, is_symmetric=is_symmetric
            ):
                matrix_float = np.random.rand(rows, cols).astype(dtype)
                quant_value_ref, scales_ref, zero_point_ref = quantize_blockwise_4bits_ref(
                    matrix_float, block_size, is_symmetric
                )
                quant_value, scales, zero_point = quantize_blockwise_4bits_target(
                    matrix_float, block_size, is_symmetric
                )
                self.assertTrue(np.allclose(scales_ref, scales))
                self.assertTrue(np.allclose(zero_point_ref, zero_point))

                n_cols, k_blocks = quant_value_ref.shape[0], quant_value_ref.shape[1]
                for c in range(n_cols):
                    for k in range(k_blocks):
                        block_idx = c * k_blocks + k
                        # The last block of a column may be only partially filled.
                        valid_len = min(block_size, rows - k * block_size)
                        expected = dequantize_blockwise_4bits(
                            quant_value_ref[c][k],
                            scales_ref[block_idx],
                            self._unpack_zero_point(zero_point_ref, block_idx),
                            valid_len,
                        )
                        actual = dequantize_blockwise_4bits(
                            quant_value[c][k],
                            scales[block_idx],
                            self._unpack_zero_point(zero_point, block_idx),
                            valid_len,
                        )
                        # Allow a difference of about one quantization level.
                        self.assertTrue(np.allclose(expected, actual, atol=1.2 * abs(scales[block_idx])))