diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 55d03c14270d3..d72e56d30b554 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -61,6 +61,11 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
"quantization/attention_quantization_impl.cuh"
+ "quantization/dequantize_blockwise.cuh"
+ "quantization/dequantize_blockwise.cu"
+ "quantization/matmul_nbits.cc"
+ "quantization/matmul_nbits.cuh"
+ "quantization/matmul_nbits.cu"
"quantization/quantize_dequantize_linear.cc"
"quantization/qordered_ops/qordered_attention_impl.cu"
"quantization/qordered_ops/qordered_attention_impl.h"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 2a16bdbf7b55d..751c892336f67 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -49,6 +49,7 @@ Do not modify directly.*
* com.microsoft.MatMulFpQ4
* com.microsoft.MatMulInteger16
* com.microsoft.MatMulIntegerToFloat
+ * com.microsoft.MatMulNBits
* com.microsoft.MaxpoolWithMask
* com.microsoft.MulInteger
* com.microsoft.MultiHeadAttention
@@ -2593,6 +2594,78 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.MatMulNBits**
+
+ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
+ 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
+ 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size.
+ And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
+ 3. Input B's scale and zero point are specified by input scales and zero_points.
+
+ Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
+ - n_blocks_per_col = (K + block_size - 1) / block_size
+ - blob_size = block_size / 8 * bits
+
+ For a block blob. It is stored in format:
+ struct Blob {
+ uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization
+ uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization
+ uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
+ }
+
+ Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
+ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
+ - [(N * n_blocks_per_col + 1) / 2] if bits <=4
+ - [N * n_blocks_per_col] if bits > 4
+
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- K : int (required)
+- size of each input feature
+- N : int (required)
+- size of each output feature
+- bits : int (required)
+- number of bits used for weight quantization (default 4)
+- block_size : int (required)
+- number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
+
+
+#### Inputs (3 - 4)
+
+
+- A : T1
+- The input tensor, not quantized
+- B : T2
+- 1-dimensional data blob
+- scales : T1
+- quantization scale
+- zero_points (optional) : T2
+- quantization zero points
+
+
+#### Outputs
+
+
+- Y : T1
+- tensor. The output tensor has the same rank as the input.
+
+
+#### Type Constraints
+
+
+- T1 : tensor(float), tensor(float16)
+- Constrain input and output types to float/half_float tensors.
+- T2 : tensor(uint8)
+- Constrain quantized weight types to uint8.
+
+
+
### **com.microsoft.MaxpoolWithMask**
For internal use.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index ce9d8aabfede3..47c3368155ff3 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -457,6 +457,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)|
+|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
|MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)|
@@ -844,6 +845,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 0ec5088808656..e8f4c91cc0d04 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -27,6 +27,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordC
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
#endif
@@ -262,6 +263,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo, // backward compatibility
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#ifndef ORT_MINIMAL_BUILD
BuildKernelCreateInfo,
#endif
diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h
new file mode 100644
index 0000000000000..11b5447d65ed2
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h
@@ -0,0 +1,129 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace onnxruntime {
+namespace contrib {
+
+#if defined(_MSC_VER)
+#define FORCEINLINE __forceinline
+#else
+#define FORCEINLINE __attribute__((always_inline)) inline
+#endif
+
+template
+struct alignas(1) BlockwiseQuantBlock {
+ static_assert(block_size % 8 == 0);
+
+ uint8_t blob_data[block_size / 8 * bits];
+
+ FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const;
+ FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const;
+
+ FORCEINLINE void quant(const T* src, T& scale, int32_t k_idx, int32_t K, int32_t N);
+ FORCEINLINE void quant(const T* src, T& scale, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N);
+};
+
+template
+struct alignas(1) BlockwiseQuantBlock {
+ static_assert(block_size % 8 == 0);
+
+ uint8_t blob_data[block_size / 2];
+
+ FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const {
+ for (int i = 0; i < block_size; i += 2) {
+ T zp_t = static_cast(float(zp));
+ if (k_idx + i < K) {
+ T x0 = static_cast(float(blob_data[i / 2] & 0xF));
+ dst[i] = scale * (x0 - zp_t);
+ }
+ if (k_idx + i + 1 < K) {
+ T x1 = static_cast(float(blob_data[i / 2] >> 4));
+ dst[i + 1] = scale * (x1 - zp_t);
+ }
+ }
+ }
+
+ FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const {
+ constexpr uint8_t zp = 8;
+ dequant(dst, scale, zp, k_idx, K);
+ }
+
+ FORCEINLINE void quant(const T* src, T& scale_block, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N) {
+ float min = static_cast(*src);
+ float max = static_cast(*src);
+ int32_t klen = std::min(block_size, K - k_idx);
+ for (int32_t kk = 0; kk < klen; kk++) {
+ const float v = static_cast(src[N * kk]);
+ if (v < min) min = v;
+ if (v > max) max = v;
+ }
+ min = std::min(min, 0.0f);
+ max = std::max(max, 0.0f);
+
+ const float scale = (max - min) / ((1 << 4) - 1);
+ scale_block = static_cast(scale);
+
+ const float reciprocal_scale = scale ? 1.0f / scale : 0.0f;
+ float zero_point_fp = min;
+ if (scale != 0.0f) {
+ zero_point_fp = 0.f - min / scale;
+ }
+
+ // Handle any clamping
+ if (zero_point_fp < 0.0f) {
+ zp = 0;
+ } else if (zero_point_fp > 15.0f) {
+ zp = 15;
+ } else {
+ zp = (uint8_t)roundf(zero_point_fp);
+ }
+
+ for (int32_t kk = 0; kk < klen; kk += 2) {
+ const float v0 = static_cast(src[N * kk]);
+ const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp)));
+
+ const float v1 = static_cast((kk + 1 < klen) ? src[N * (kk + 1)] : 0.f);
+ const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp)));
+
+ blob_data[kk / 2] = vi0 | (vi1 << 4);
+ }
+ }
+
+ FORCEINLINE void quant(const T* src, T& scale_block, int32_t k_idx, int32_t K, int32_t N) {
+ float amax = 0.0f; // abs(max)
+ float max = 0.0f;
+
+ int32_t klen = std::min(block_size, K - k_idx);
+
+ for (int32_t kk = 0; kk < klen; kk++) {
+ const float v = static_cast(src[N * kk]);
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
+ }
+
+ const float scale = max / (-8.f);
+ scale_block = static_cast(scale);
+ const float reciprocal_scale = scale ? 1.0f / scale : 0.0f;
+
+ for (int32_t kk = 0; kk < klen; kk += 2) {
+ const float v0 = src[N * kk] * reciprocal_scale;
+ const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 + 8.f)));
+
+ const float v1 = (kk + 1 < klen) ? src[N * (kk + 1)] * reciprocal_scale : 0;
+ const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 + 8.f)));
+
+ blob_data[kk / 2] = vi0 | (vi1 << 4);
+ }
+ }
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h
new file mode 100644
index 0000000000000..8811e5649fc19
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h
@@ -0,0 +1,174 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "blockwise_quant_block.h"
+
+#include
+
+#include "core/common/safeint.h"
+#include "core/framework/float16.h"
+#include "core/platform/threadpool.h"
+#include
+
+namespace onnxruntime {
+namespace contrib {
+
+template
+void QuantizeBlockwise(
+ uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ]
+ const T* src, // shape: [K, N]
+ T* scale, // shape: [N * block_per_K]
+ uint8_t* zero_points, // shape: [N * block_per_K] if bits > 4 else [(N *block_per_K + 1) / 2]
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ BlockwiseQuantBlock* dst_blob =
+ reinterpret_cast*>(dst);
+
+ int32_t block_per_K = (K + block_size - 1) / block_size;
+ int32_t total_block_count = N * block_per_K;
+
+ std::vector zero_points_tmp; // to avoid race condition
+ (void)zero_points_tmp;
+ uint8_t* zero_points_tmp_ptr = zero_points;
+ if (bits <= 4 && zero_points != nullptr) {
+ zero_points_tmp.resize(total_block_count, 0);
+ zero_points_tmp_ptr = zero_points_tmp.data();
+ }
+
+ concurrency::ThreadPool::TryBatchParallelFor(
+ thread_pool,
+ total_block_count,
+ [&](ptrdiff_t block_idx) {
+ int32_t n = static_cast(block_idx / block_per_K);
+ int32_t k_block_idx = static_cast(block_idx % block_per_K);
+ int32_t k = k_block_idx * block_size;
+ BlockwiseQuantBlock* blob_ptr = dst_blob + block_idx;
+ size_t offset = SafeInt(k) * N + n;
+ if (nullptr != zero_points_tmp_ptr) {
+ blob_ptr->quant(src + offset, scale[block_idx], zero_points_tmp_ptr[block_idx], k, K, N);
+ } else {
+ blob_ptr->quant(src + offset, scale[block_idx], k, K, N);
+ }
+ },
+ 0);
+
+ if (bits <= 4 && zero_points != nullptr) { // compact zero points
+ for (int32_t zp_idx = 0; zp_idx < total_block_count / 2; zp_idx++) {
+ zero_points[zp_idx] = ((zero_points_tmp[zp_idx * 2]) | (zero_points_tmp[zp_idx * 2 + 1] << 4));
+ }
+ if (total_block_count & 1) {
+ zero_points[total_block_count / 2] = (zero_points[total_block_count / 2] & 0xf0) | zero_points_tmp[total_block_count - 1];
+ }
+ }
+}
+
+#define QuantizeBlockwise4Bits(block_size) \
+ QuantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool);
+
+template
+void QuantizeBlockwise(
+ uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ]
+ const T* src, // shape: [K, N]
+ T* scale, // shape: [N, block_per_K]
+ uint8_t* zero_points, // shape: [N, block_per_K]
+ int32_t block_size,
+ int32_t bits,
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ ORT_ENFORCE(bits == 4, "only 4 bits is supported now");
+
+ if (16 == block_size) {
+ QuantizeBlockwise4Bits(16);
+ } else if (32 == block_size) {
+ QuantizeBlockwise4Bits(32);
+ } else if (64 == block_size) {
+ QuantizeBlockwise4Bits(64);
+ } else if (128 == block_size) {
+ QuantizeBlockwise4Bits(128);
+ } else if (256 == block_size) {
+ QuantizeBlockwise4Bits(256);
+ } else {
+ ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
+ }
+}
+
+#undef QuantizeBlockwise4Bits
+
+template
+void DequantizeBlockwise(
+ T* dst, // shape: [N, K]
+ const uint8_t* src, // shape: [N, block_per_K, block_blob_size]
+ const T* scale, // shape: [N, block_per_K]
+ const uint8_t* zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2]
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ int32_t block_per_K = (K + block_size - 1) / block_size;
+ int32_t task_count = N * block_per_K;
+
+ const BlockwiseQuantBlock* src_blob =
+ reinterpret_cast*>(src);
+
+ concurrency::ThreadPool::TryBatchParallelFor(
+ thread_pool,
+ task_count,
+ [&](ptrdiff_t task_idx) {
+ int32_t n = static_cast(task_idx / block_per_K);
+ int32_t k_block_idx = static_cast(task_idx % block_per_K);
+ int32_t k = k_block_idx * block_size;
+ const BlockwiseQuantBlock* blob_ptr = src_blob + task_idx;
+ size_t offset = SafeInt(n) * K + k;
+ if (nullptr != zero_points) {
+ if constexpr (bits > 4) { // zero point is stored with a byte
+ blob_ptr->dequant(dst + offset, scale[task_idx], zero_points[task_idx], k, K);
+ } else { // zero points is stored with 4bits
+ uint8_t zp = zero_points[task_idx / 2];
+ zp = (task_idx & 1) ? (zp >> 4) : (zp & 0xf);
+ blob_ptr->dequant(dst + offset, scale[task_idx], zp, k, K);
+ }
+ } else {
+ blob_ptr->dequant(dst + offset, scale[task_idx], k, K);
+ }
+ },
+ 0);
+}
+
+#define DequantizeBlockwise4Bits(block_size) \
+ DequantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool);
+
+template
+void DequantizeBlockwise(
+ T* dst, // [N, K]
+ const uint8_t* src, // [N, block_per_K, block_blob_size]
+ const T* scale, // [N, block_per_K]
+ const uint8_t* zero_points, // [N, block_per_K]
+ int32_t block_size,
+ int32_t bits,
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ ORT_ENFORCE(bits == 4, "only 4 bits is supported now");
+
+ if (16 == block_size) {
+ DequantizeBlockwise4Bits(16);
+ } else if (32 == block_size) {
+ DequantizeBlockwise4Bits(32);
+ } else if (64 == block_size) {
+ DequantizeBlockwise4Bits(64);
+ } else if (128 == block_size) {
+ DequantizeBlockwise4Bits(128);
+ } else if (256 == block_size) {
+ DequantizeBlockwise4Bits(256);
+ } else {
+ ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
+ }
+}
+
+#undef DequantizeBlockwise4Bits
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
new file mode 100644
index 0000000000000..57aada94be39c
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -0,0 +1,114 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/safeint.h"
+#include "core/framework/op_kernel.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "core/providers/common.h"
+#include "dequantize_blockwise.h"
+#include "core/mlas/inc/mlas.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+class MatMulNBits final : public OpKernel {
+ public:
+ MatMulNBits(const OpKernelInfo& info) : OpKernel(info) {
+ ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_));
+ ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_));
+ ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_));
+ ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_));
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int64_t K_;
+ int64_t N_;
+ int64_t block_size_;
+ int64_t nbits_;
+};
+
+Status MatMulNBits::Compute(OpKernelContext* ctx) const {
+ concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
+
+ const Tensor* a = ctx->Input(0);
+ const Tensor* b = ctx->Input(1);
+ const Tensor* scales = ctx->Input(2);
+ const Tensor* zero_points = ctx->Input(3);
+
+ const auto* a_data = a->Data();
+ const uint8_t* b_data = b->Data();
+ const auto* scales_data = scales->Data();
+ const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data();
+
+ AllocatorPtr allocator;
+ auto status = ctx->GetTempSpaceAllocator(&allocator);
+ ORT_RETURN_IF_ERROR(status);
+ auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_);
+ DequantizeBlockwise(tmp_b_data_ptr.get(),
+ b_data,
+ scales_data,
+ zero_points_data,
+ static_cast(block_size_),
+ static_cast(nbits_),
+ static_cast(N_),
+ static_cast(K_),
+ thread_pool);
+
+#if 0 // for debug
+ auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_);
+ MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_);
+#endif
+
+ TensorShape b_shape({N_, K_});
+
+ MatMulComputeHelper helper;
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
+
+ Tensor* y = ctx->Output(0, helper.OutputShape());
+
+ // Bail out early if the output is going to be empty
+ if (y->Shape().Size() == 0)
+ return Status::OK();
+
+ auto* y_data = y->MutableData();
+
+ const size_t max_len = helper.OutputOffsets().size();
+ const size_t M = static_cast(helper.M());
+ const size_t N = static_cast(helper.N());
+ const size_t K = static_cast(helper.K());
+ const size_t lda = helper.Lda(false);
+ const size_t ldb = helper.Ldb(true);
+
+ // TODO: implement with native kernel
+ std::vector data(max_len);
+ for (size_t i = 0; i < max_len; i++) {
+ data[i].BIsPacked = false;
+ data[i].A = a_data + helper.LeftOffsets()[i];
+ data[i].lda = lda;
+ data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i];
+ data[i].ldb = ldb;
+ data[i].C = y_data + helper.OutputOffsets()[i];
+ data[i].ldc = N;
+ data[i].alpha = 1.f;
+ data[i].beta = 0.0f;
+ }
+ MlasGemmBatch(CblasNoTrans, CblasTrans,
+ M, N, K, data.data(), max_len, thread_pool);
+
+ return Status::OK();
+}
+
+ONNX_OPERATOR_KERNEL_EX(
+ MatMulNBits,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ MatMulNBits);
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 71ee5ae1ddbe6..0595350c49827 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -113,6 +113,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear);
@@ -264,6 +266,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
new file mode 100644
index 0000000000000..8c328d00b44d0
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
@@ -0,0 +1,131 @@
+// Modifications: scaling is moved from masked softmax to the gemm before that.
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+#include
+#include
+#include
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cuda_common.h"
+#include "dequantize_blockwise.cuh"
+
+using namespace onnxruntime::cuda;
+using namespace cub;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) {
+ half2 scale_half2 = {scale, scale};
+ half zp_adjust = -scale * __short2half_rn(zp);
+ half2 zp_adjust2 = {zp_adjust, zp_adjust};
+
+ alignas(16) half2 results[4];
+ half v0 = __uint2half_rn(values_quant & 0xF);
+ half v1 = __uint2half_rn((values_quant >> 4) & 0xF);
+ results[0] = __halves2half2(v0, v1) * scale_half2 + zp_adjust2;
+
+ half v2 = __uint2half_rn((values_quant >> 8) & 0xF);
+ half v3 = __uint2half_rn((values_quant >> 12) & 0xF);
+ results[1] = __halves2half2(v2, v3) * scale_half2 + zp_adjust2;
+
+ half v4 = __uint2half_rn((values_quant >> 16) & 0xF);
+ half v5 = __uint2half_rn((values_quant >> 20) & 0xF);
+ results[2] = __halves2half2(v4, v5) * scale_half2 + zp_adjust2;
+
+ half v6 = __uint2half_rn((values_quant >> 24) & 0xF);
+ half v7 = __uint2half_rn((values_quant >> 28) & 0xF);
+ results[3] = __halves2half2(v6, v7) * scale_half2 + zp_adjust2;
+ *(reinterpret_cast(output)) = *(reinterpret_cast(results));
+}
+
+__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, float scale, float zp, float* output) {
+ float zp_adjust = -scale * zp;
+ output[0] = float(values_quant & 0xF) * scale + zp_adjust;
+ output[1] = float((values_quant >> 4) & 0xF) * scale + zp_adjust;
+ output[2] = float((values_quant >> 8) & 0xF) * scale + zp_adjust;
+ output[3] = float((values_quant >> 12) & 0xF) * scale + zp_adjust;
+ output[4] = float((values_quant >> 16) & 0xF) * scale + zp_adjust;
+ output[5] = float((values_quant >> 20) & 0xF) * scale + zp_adjust;
+ output[6] = float((values_quant >> 24) & 0xF) * scale + zp_adjust;
+ output[7] = float((values_quant >> 28) & 0xF) * scale + zp_adjust;
+}
+
+template
+__global__ void Dequantize4BitsKernel(
+ T* output,
+ const uint8_t* quant_data,
+ const T* scale_data,
+ const uint8_t* zero_points,
+ int block_size,
+ int blocks_per_threadblock,
+ int shift) {
+ int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
+ int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
+ uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2));
+ T scale = *(scale_data + block_id);
+ uint8_t zp = 8;
+ if (zero_points) {
+ zp = (block_id & 0x01) ? (zero_points[block_id / 2] >> 4) : (zero_points[block_id / 2] & 0x0f);
+ }
+
+ output = output + element_offset;
+ DequantizeEightElements(quant_value, scale, static_cast(zp), output);
+}
+
+template
+Status Dequantize4Bits(
+ T* output,
+ const uint8_t* quant_data,
+ const T* scales_data,
+ const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2]
+ int k,
+ int n,
+ int block_size,
+ cudaStream_t stream) {
+ // k is padded and equal to block_per_K * block_size
+ ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size");
+ constexpr int element_per_thread = 8;
+ int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
+ int blocks_per_K = k / block_size;
+ int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock));
+ int shift = static_cast(log2f(float(block_size)));
+
+ Dequantize4BitsKernel<<>>(
+ output,
+ quant_data,
+ scales_data,
+ zero_points,
+ block_size,
+ blocks_per_threadblock,
+ shift);
+
+ return Status::OK();
+}
+
+template Status Dequantize4Bits(
+ float* output,
+ const uint8_t* quant_data,
+ const float* scales_data,
+ const uint8_t* zero_points,
+ int k,
+ int n,
+ int block_size,
+ cudaStream_t stream);
+
+template Status Dequantize4Bits(
+ half* output,
+ const uint8_t* quant_data,
+ const half* scales_data,
+ const uint8_t* zero_points,
+ int k,
+ int n,
+ int block_size,
+ cudaStream_t stream);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
new file mode 100644
index 0000000000000..741ce1e735b42
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+template
+Status Dequantize4Bits(
+ T* output,
+ const uint8_t* quant_data,
+ const T* scales_data,
+ const uint8_t* zero_points,
+ int k,
+ int n,
+ int block_size,
+ cudaStream_t stream);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
new file mode 100644
index 0000000000000..1f540fa45e7a8
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -0,0 +1,148 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+//
+// This module define MatMulFp32Q4 operator, it is basically
+// matmul float32 with right hand side being a 2-D matrix
+// pre-packed and block-compacted into int4
+//
+
+#include "core/common/safeint.h"
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "matmul_nbits.cuh"
+#include "dequantize_blockwise.cuh"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+using namespace onnxruntime::cuda;
+
+template
+class MatMulNBits final : public CudaKernel {
+ public:
+ MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) {
+ ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_));
+ ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_));
+ ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_));
+ ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_));
+ }
+
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ int64_t K_;
+ int64_t N_;
+ int64_t block_size_;
+ int64_t nbits_;
+};
+
+template
+Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const {
+ const Tensor* a = ctx->Input(0);
+ const Tensor* b = ctx->Input(1);
+ const Tensor* scales = ctx->Input(2);
+ const Tensor* zero_points = ctx->Input(3);
+
+ const auto* a_data = a->Data();
+ const uint8_t* blob_data = b->Data();
+ const auto* scales_data = scales->Data();
+ const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data();
+
+ ORT_ENFORCE(nbits_ == 4, "only 4 bits is supported now");
+
+ typedef typename ToCudaType::MappedType CudaT;
+
+ constexpr bool transa = false;
+ constexpr bool transb = true;
+ MatMulComputeHelper helper;
+ TensorShape b_shape({N_, K_});
+ ORT_RETURN_IF_ERROR(
+ helper.Compute(a->Shape(), b_shape, transa, transb));
+
+ Tensor* Y = ctx->Output(0, helper.OutputShape());
+ // Bail out early if the output is going to be empty
+ if (Y->Shape().Size() == 0) return Status::OK();
+
+ bool is_4bit_done = TryMatMul4Bits(
+ reinterpret_cast(Y->MutableData()),
+ reinterpret_cast(a_data),
+ blob_data,
+ reinterpret_cast(scales_data),
+ zero_points_data,
+ SafeInt(helper.M()),
+ SafeInt(helper.N()),
+ SafeInt(helper.K()),
+ SafeInt(block_size_),
+ static_cast(ctx->GetComputeStream()->GetHandle()));
+ if (!is_4bit_done) {
+ int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_;
+ IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream());
+ auto* b_data = b_data_ptr.get();
+ ORT_RETURN_IF_ERROR(Dequantize4Bits(reinterpret_cast(b_data),
+ blob_data,
+ reinterpret_cast(scales_data),
+ zero_points_data,
+ SafeInt(K_padded),
+ SafeInt(N_),
+ SafeInt(block_size_),
+ static_cast(ctx->GetComputeStream()->GetHandle())));
+#if 0
+ cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle()));
+ T* b_data_cpu = new T[K_ * N_];
+ cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost);
+ delete[] b_data_cpu;
+#endif
+
+ const CudaT alpha = ToCudaType::FromFloat(1.f);
+ const CudaT zero = ToCudaType::FromFloat(0.f);
+
+ if (helper.OutputOffsets().size() == 1) {
+ CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
+ GetCublasHandle(ctx),
+ CUBLAS_OP_T,
+ CUBLAS_OP_N,
+ SafeInt(helper.N()),
+ SafeInt(helper.M()),
+ SafeInt(helper.K()),
+ &alpha,
+ reinterpret_cast(b_data),
+ SafeInt(K_padded),
+ reinterpret_cast(a_data),
+ helper.Lda(transa),
+ &zero,
+ reinterpret_cast(Y->MutableData()),
+ helper.Ldc(),
+ GetDeviceProp()));
+ }
+ }
+
+ return Status::OK();
+}
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ MatMulNBits,
+ kMSDomain,
+ 1,
+ float,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ MatMulNBits);
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ MatMulNBits,
+ kMSDomain,
+ 1,
+ MLFloat16,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ MatMulNBits);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
new file mode 100644
index 0000000000000..c5a1e0f9ee451
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
@@ -0,0 +1,209 @@
+// Modifications: scaling is moved from masked softmax to the gemm before that.
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+#include
+#include
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cuda_common.h"
+#include "matmul_nbits.cuh"
+
+using namespace onnxruntime::cuda;
+using namespace cub;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) {
+ half2 scale_half2 = {scale, scale};
+ half zp_adjust = -scale * __short2half_rn(zp);
+ half2 zp_adjust2 = {zp_adjust, zp_adjust};
+ uint4 vec_a = *(reinterpret_cast(a));
+
+ half2 element01 = __halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF));
+ half2 v0 = element01 * scale_half2 + zp_adjust2;
+
+ half2 element23 = __halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF));
+ half2 v1 = element23 * scale_half2 + zp_adjust2;
+
+ half2 element45 = __halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF));
+ half2 v2 = element45 * scale_half2 + zp_adjust2;
+
+ half2 element67 = __halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF));
+ half2 v3 = element67 * scale_half2 + zp_adjust2;
+
+ v0 = v0 * (*(reinterpret_cast(&(vec_a.x))));
+ v1 = v1 * (*(reinterpret_cast(&(vec_a.y))));
+ v2 = v2 * (*(reinterpret_cast(&(vec_a.z)))) + v0;
+ v3 = v3 * (*(reinterpret_cast(&(vec_a.w)))) + v1;
+ v3 = v2 + v3;
+ return float(v3.x) + float(v3.y);
+}
+
+__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) {
+ float4 a_vec_0 = *(reinterpret_cast(a));
+ float4 a_vec_1 = *(reinterpret_cast(a + 4));
+
+ float zp_adjust = -scale * zp;
+ float v0 = float(values_quant & 0xF) * scale + zp_adjust;
+ float v1 = float((values_quant >> 4) & 0xF) * scale + zp_adjust;
+ float v2 = float((values_quant >> 8) & 0xF) * scale + zp_adjust;
+ float v3 = float((values_quant >> 12) & 0xF) * scale + zp_adjust;
+ float v4 = float((values_quant >> 16) & 0xF) * scale + zp_adjust;
+ float v5 = float((values_quant >> 20) & 0xF) * scale + zp_adjust;
+ float v6 = float((values_quant >> 24) & 0xF) * scale + zp_adjust;
+ float v7 = float((values_quant >> 28) & 0xF) * scale + zp_adjust;
+
+ v0 = v0 * a_vec_0.x;
+ v1 = v1 * a_vec_0.y;
+ v2 = v2 * a_vec_0.z;
+ v3 = v3 * a_vec_0.w;
+ v4 = v4 * a_vec_1.x + v0;
+ v5 = v5 * a_vec_1.y + v1;
+ v6 = v6 * a_vec_1.z + v2;
+ v7 = v7 * a_vec_1.w + v3;
+ return v4 + v5 + v6 + v7;
+}
+
+constexpr int kColsPerThreadBlock = 8;
+constexpr int kWarpSize = 32;
+
+// kernel for 4bits quantized gemv, i.e., computing A(1,K) x B(K, N)
+// B(K, N) is quantized blockwise with 4bits and stored as [N, (K + block_size - 1)/block_size, blob]
+// The thread block size is (kWarpSize, kColsPerThreadBlock) and grid size is (N/kColsPerThreadBlock, 1)
+// Each thread block computes [1, K] x [kColsPerThreadBlock, (K + block_size - 1)/block_size, blob],
+// i.e., computing kColsPerThreadBlock per block and a warp reduce (1, K) x (K)
+template
+__global__ void MatMulFloatInt4Kernel(
+ T* output,
+ const T* a_data,
+ const uint8_t* b_data_quant,
+ const T* scales_data,
+ const uint8_t* zero_points,
+ int m,
+ int n,
+ int k) {
+ int n_block_id = blockIdx.x;
+ int m_id = blockIdx.y;
+ int lane_id = threadIdx.x;
+ int warp_id = threadIdx.y;
+ int n_id = n_block_id * kColsPerThreadBlock + warp_id;
+ int blocks_per_K = (k + block_size - 1) / block_size;
+ int thread_id = warp_id * kWarpSize + lane_id;
+ constexpr int k_per_iter = 256;
+ int k_iter = k / k_per_iter;
+
+ extern __shared__ char shared_buffer[];
+
+ // load scale to shared buffer
+ T* b_scale_vec = (T*)shared_buffer;
+ uint8_t* b_zp_vec = reinterpret_cast(b_scale_vec + kColsPerThreadBlock * blocks_per_K);
+ int offset = n_block_id * kColsPerThreadBlock * blocks_per_K;
+ for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) {
+ b_scale_vec[i] = scales_data[offset + i];
+ }
+ for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K / 2; i += kColsPerThreadBlock * kWarpSize) {
+ b_zp_vec[i] = zero_points != nullptr ? zero_points[offset / 2 + i] : uint8_t(0x88);
+ }
+ __syncthreads();
+
+ a_data += m_id * k;
+ b_data_quant += n_id * blocks_per_K * (block_size / 2);
+
+ float sum = 0.f;
+ int k_id = 0;
+ for (; k_id < (k & 0xffffff00); k_id += k_per_iter) {
+ uint32_t value = *(reinterpret_cast(b_data_quant + (k_id >> 1) + lane_id * 4));
+ int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size;
+ T scale = b_scale_vec[block_idx];
+ uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f);
+ sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
+ }
+
+ // handle reminder
+ if (k_id + lane_id * 8 < k) {
+ uint32_t value = *(reinterpret_cast(b_data_quant + k_iter * 128 + lane_id * 4));
+ int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size;
+ T scale = b_scale_vec[block_idx];
+ uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f);
+ sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
+ }
+
+ // warp reduction
+ for (int i = 16; i > 0; i = i / 2) {
+ sum += __shfl_down_sync(0xffffffff, sum, i);
+ }
+
+ if (lane_id == 0) {
+ output[m_id * n + n_id] = sum;
+ }
+}
+
+template
+bool TryMatMul4Bits(
+ T* output,
+ const T* a_data,
+ const uint8_t* b_data_quant,
+ const T* scales_data,
+ const uint8_t* zero_points,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream) {
+ if (n % kColsPerThreadBlock != 0 || k % 8 != 0 || m > 1) {
+ return false;
+ }
+ dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m);
+ dim3 threads(kWarpSize, kColsPerThreadBlock);
+ int shared_mem_size = (sizeof(T) + 1) * ((k + block_size - 1) / block_size * kColsPerThreadBlock);
+
+ if (16 == block_size) {
+ MatMulFloatInt4Kernel<<>>(
+ output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
+ } else if (32 == block_size) {
+ MatMulFloatInt4Kernel<<>>(
+ output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
+ } else if (64 == block_size) {
+ MatMulFloatInt4Kernel<<>>(
+ output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
+ } else if (128 == block_size) {
+ MatMulFloatInt4Kernel<<>>(
+ output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
+ } else {
+ ORT_THROW("block size ", block_size, " is not supported");
+ }
+
+ return true;
+}
+
+template bool TryMatMul4Bits(
+ float* output,
+ const float* a_data,
+ const uint8_t* b_data_quant,
+ const float* scales_data,
+ const uint8_t* zero_points,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream);
+
+template bool TryMatMul4Bits(
+ half* output,
+ const half* a_data,
+ const uint8_t* b_data_quant,
+ const half* scales_data,
+ const uint8_t* zero_points,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh
new file mode 100644
index 0000000000000..847a549c342c9
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template
+bool TryMatMul4Bits(
+ T* output,
+ const T* a_data,
+ const uint8_t* b_data_quant,
+ const T* scales_data,
+ const uint8_t* zero_points,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 21f9db7e486be..03930d76467ad 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -2384,6 +2384,35 @@ ONNX_MS_OPERATOR_SET_SCHEMA(CropAndResize, 1,
a fixed size = [crop_height, crop_width]. The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth].
The resizing is corner aligned.)DOC"));
+static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx,
+ int64_t K,
+ int64_t N) {
+ int input_a_idx = 0;
+ if (!hasInputShape(ctx, input_a_idx)) {
+ return;
+ }
+
+ const auto& a_shape = ctx.getInputType(input_a_idx)->tensor_type().shape();
+ if (a_shape.dim_size() == 0) {
+ fail_shape_inference("Input tensors of wrong rank (0).");
+ }
+
+ // TODO: check B shape
+
+ const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1);
+ if (dim_last.has_dim_value() && dim_last.dim_value() != K) {
+ fail_shape_inference("Incompatible dimensions for matrix multiplication");
+ }
+
+ ONNX_NAMESPACE::TensorShapeProto resultShape;
+ for (int i = 0; i < a_shape.dim_size() - 1; ++i) {
+ *resultShape.add_dim() = a_shape.dim(i);
+ }
+ resultShape.add_dim()->set_dim_value(N);
+
+ *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape;
+}
+
void RegisterContribSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema);
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema);
@@ -2972,6 +3001,55 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t
}
});
+ static const char* MatMulNBits_ver1_doc = R"DOC(
+MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
+ 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
+ 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size.
+ And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
+ 3. Input B's scale and zero point are specified by input scales and zero_points.
+
+Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
+- n_blocks_per_col = (K + block_size - 1) / block_size
+- blob_size = block_size / 8 * bits
+
+ For a block blob. It is stored in format:
+ struct Blob {
+ uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization
+ uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization
+ uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
+ }
+
+Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
+Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
+ - [(N * n_blocks_per_col + 1) / 2] if bits <=4
+ - [N * n_blocks_per_col] if bits > 4
+
+)DOC";
+
+ ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits)
+ .SetDomain(kMSDomain)
+ .SinceVersion(1)
+ .SetDoc(MatMulNBits_ver1_doc)
+ .Attr("K", "size of each input feature", AttributeProto::INT)
+ .Attr("N", "size of each output feature", AttributeProto::INT)
+ .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
+ .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
+ .Input(0, "A", "The input tensor, not quantized", "T1")
+ .Input(1, "B", "1-dimensional data blob", "T2")
+ .Input(2, "scales", "quantization scale", "T1")
+ .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional)
+ .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
+ .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
+ .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ // Type inference
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ // Shape inference
+ int64_t in_features = getAttribute(ctx, "K", -1);
+ int64_t out_features = getAttribute(ctx, "N", -1);
+ MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
+ });
+
#ifdef ENABLE_ATEN
ONNX_CONTRIB_OPERATOR_SCHEMA(ATen)
.SetDomain(kPytorchAtenDomain)
diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
index f0ab7869b7d50..c0b282b202ef6 100644
--- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
+++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#pragma once
namespace onnxruntime {
extern ProviderHost* g_host;
diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc
index 6824a5d0bf98f..1d8ca195ab82b 100644
--- a/onnxruntime/python/onnxruntime_pybind_module.cc
+++ b/onnxruntime/python/onnxruntime_pybind_module.cc
@@ -17,6 +17,7 @@ static constexpr bool HAS_COLLECTIVE_OPS = false;
#endif
void CreateInferencePybindStateModule(py::module& m);
+void CreateQuantPybindModule(py::module& m);
PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
CreateInferencePybindStateModule(m);
@@ -30,6 +31,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
m.def("get_version_string", []() -> std::string { return ORT_VERSION; });
m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; });
m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; });
+ CreateQuantPybindModule(m);
}
} // namespace python
} // namespace onnxruntime
diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc
new file mode 100644
index 0000000000000..52ea677d5141d
--- /dev/null
+++ b/onnxruntime/python/onnxruntime_pybind_quant.cc
@@ -0,0 +1,73 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+#include
+
+#include "contrib_ops/cpu/quantization/dequantize_blockwise.h"
+#include "core/util/thread_utils.h"
+
+namespace pybind11 {
+namespace detail {
+// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
+constexpr int NPY_FLOAT16 = 23;
+template <>
+struct npy_format_descriptor {
+ static constexpr auto name = _("float16");
+ static pybind11::dtype dtype() {
+ handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
+ return reinterpret_borrow(ptr);
+ }
+ static std::string format() {
+ // following: https://docs.python.org/3/library/struct.html#format-characters
+ return "e";
+ }
+};
+} // namespace detail
+} // namespace pybind11
+
+namespace onnxruntime {
+namespace python {
+
+namespace py = pybind11;
+using namespace onnxruntime;
+
+template
+void QuantizeMatMul4BitsBlockwise(
+ py::array_t dst, // shape: [ N, block_per_K, block_blob_size ]
+ py::array_t src, // shape: [K, N]
+ py::array_t scale, // shape: [N, block_per_K]
+ py::array_t zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2]
+ int32_t block_size,
+ int32_t N,
+ int32_t K,
+ bool is_symmetric) {
+ OrtThreadPoolParams to;
+ auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
+ concurrency::ThreadPoolType::INTRA_OP);
+
+ py::buffer_info dst_buf = dst.request();
+ py::buffer_info src_buf = src.request();
+ py::buffer_info scale_buf = scale.request();
+ py::buffer_info zp_buf = zero_points.request();
+
+ contrib::QuantizeBlockwise(
+ static_cast(dst_buf.ptr),
+ static_cast(src_buf.ptr),
+ static_cast(scale_buf.ptr),
+ is_symmetric ? nullptr : static_cast(zp_buf.ptr),
+ block_size,
+ 4,
+ N,
+ K,
+ tp.get());
+}
+
+void CreateQuantPybindModule(py::module& m) {
+ m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise);
+ m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise);
+}
+
+} // namespace python
+} // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/device_array.h b/onnxruntime/python/tools/kernel_explorer/device_array.h
index bb868c2b7a59a..12c526fa0c813 100644
--- a/onnxruntime/python/tools/kernel_explorer/device_array.h
+++ b/onnxruntime/python/tools/kernel_explorer/device_array.h
@@ -62,8 +62,8 @@ class DeviceArray {
private:
std::shared_ptr device_;
void* host_;
- ssize_t size_;
- ssize_t itemsize_;
+ py::ssize_t size_;
+ py::ssize_t itemsize_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu
new file mode 100644
index 0000000000000..9b5e4079a7e31
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu
@@ -0,0 +1,78 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This file serve as a simple example for adding a tunable op to onnxruntime.
+
+#include
+#include
+#include
+
+#include
+
+#include "core/providers/cuda/tunable/cuda_tunable.h"
+#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
+#include "python/tools/kernel_explorer/device_array.h"
+#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh"
+
+namespace py = pybind11;
+
+namespace onnxruntime {
+
+// Extend the OpParams so that all specializations have the same parameter passing interface
+template
+struct DequantizeInt4Params : cuda::tunable::OpParams {
+ std::string Signature() const override { return std::to_string(n_); }
+
+ T* output_;
+ const uint8_t* quant_;
+ const T* scales_;
+ const uint8_t* zero_points_;
+ int n_;
+ int k_;
+};
+
+template
+class DequantizeInt4 : public IKernelExplorer {
+ public:
+ DequantizeInt4(DeviceArray& output, DeviceArray& quant, DeviceArray& scales, int n, int k) {
+ params_.tuning_ctx = TuningContext();
+ params_.stream = Stream();
+ params_.output_ = static_cast(output.ptr());
+ params_.quant_ = static_cast(quant.ptr());
+ params_.scales_ = static_cast(scales.ptr());
+ params_.zero_points_ = nullptr;
+ params_.n_ = n;
+ params_.k_ = k;
+ }
+
+ void Run() override {
+ ORT_THROW_IF_ERROR(contrib::cuda::Dequantize4Bits(
+ params_.output_,
+ params_.quant_,
+ params_.scales_,
+ params_.zero_points_,
+ params_.k_,
+ params_.n_,
+ 32,
+ params_.StreamHandle()));
+ }
+
+ private:
+ // A VectorAddOp is a callable that can process const VectorAddParams*
+ using ParamsT = DequantizeInt4Params;
+ ParamsT params_{};
+};
+
+#define REGISTER_OP(name, type) \
+ py::class_>(m, #name "_" #type) \
+ .def(py::init()) \
+ .def("SetRepeats", &name::SetRepeats) \
+ .def("Profile", &name::Profile) \
+ .def("Run", &name::Run);
+
+KE_REGISTER(m) {
+ REGISTER_OP(DequantizeInt4, half);
+ REGISTER_OP(DequantizeInt4, float);
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu
new file mode 100644
index 0000000000000..fd9e9c4fd1612
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu
@@ -0,0 +1,94 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This file serve as a simple example for adding a tunable op to onnxruntime.
+
+#include
+#include
+
+#include
+
+#include
+
+#include "core/providers/cuda/tunable/cuda_tunable.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "core/providers/cuda/cuda_stream_handle.h"
+#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
+#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh"
+#include "contrib_ops/cuda/quantization/matmul_nbits.cuh"
+
+namespace py = pybind11;
+
+namespace onnxruntime {
+
+// Extend the OpParams so that all specializations have the same parameter passing interface
+template
+struct GemmBenchmarkParams : cuda::tunable::OpParams {
+ std::string Signature() const override { return std::to_string(n_); }
+
+ T* output_;
+ const T* a_;
+ const T* b_;
+ int m_;
+ int n_;
+ int k_;
+ cublasHandle_t cublas_handle;
+};
+
+template
+class GemmBenchmark : public IKernelExplorer {
+ public:
+ GemmBenchmark(DeviceArray& output, DeviceArray& a, DeviceArray& b, int m, int n, int k) {
+ params_.tuning_ctx = TuningContext();
+ params_.stream = Stream();
+ params_.output_ = static_cast(output.ptr());
+ params_.a_ = static_cast(a.ptr());
+ params_.b_ = static_cast(b.ptr());
+ params_.m_ = m;
+ params_.n_ = n;
+ params_.k_ = k;
+
+ CUBLAS_CALL_THROW(cublasCreate(&(params_.cublas_handle)));
+ CUDA_CALL_THROW(cudaGetDeviceProperties(&device_prop_, 0));
+ }
+
+ void Run() override {
+ typedef typename ToCudaType::MappedType CudaT;
+ CudaT one = ToCudaType::FromFloat(1.0f);
+ CudaT zero = ToCudaType::FromFloat(0.0f);
+ CUBLAS_CALL_THROW(cublasGemmHelper(
+ params_.cublas_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ params_.n_, params_.m_, params_.k_,
+ &one,
+ reinterpret_cast(params_.b_),
+ params_.n_,
+ reinterpret_cast(params_.a_),
+ params_.k_,
+ &zero,
+ params_.output_,
+ params_.n_,
+ device_prop_));
+ }
+
+ private:
+ // A VectorAddOp is a callable that can process const VectorAddParams*
+ using ParamsT = GemmBenchmarkParams;
+ ParamsT params_{};
+ cudaDeviceProp device_prop_;
+};
+
+#define REGISTER_OP(name, type) \
+ py::class_>(m, #name "_" #type) \
+ .def(py::init()) \
+ .def("SetRepeats", &name::SetRepeats) \
+ .def("Profile", &name::Profile) \
+ .def("Run", &name::Run);
+
+KE_REGISTER(m) {
+ REGISTER_OP(GemmBenchmark, half);
+ REGISTER_OP(GemmBenchmark, float);
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu
new file mode 100644
index 0000000000000..0c7be272806b5
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu
@@ -0,0 +1,98 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This file serve as a simple example for adding a tunable op to onnxruntime.
+
+#include
+#include
+#include
+
+#include
+
+#include "core/providers/cuda/tunable/cuda_tunable.h"
+#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
+#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh"
+#include "contrib_ops/cuda/quantization/matmul_nbits.cuh"
+
+namespace py = pybind11;
+
+namespace onnxruntime {
+
+// Extend the OpParams so that all specializations have the same parameter passing interface
+template
+struct MatrixFloatInt4Params : cuda::tunable::OpParams {
+ std::string Signature() const override { return std::to_string(n_); }
+
+ T* output_;
+ const T* a_;
+ const uint8_t* b_;
+ const T* scales_;
+ const uint8_t* zero_points_;
+ int m_;
+ int n_;
+ int k_;
+};
+
+template
+class MatrixFloatInt4 : public IKernelExplorer {
+ public:
+ MatrixFloatInt4(DeviceArray& output,
+ DeviceArray& a,
+ DeviceArray& b,
+ DeviceArray& scales,
+ int m, int n, int k) {
+ params_.tuning_ctx = TuningContext();
+ params_.stream = Stream();
+ params_.output_ = static_cast(output.ptr());
+ params_.a_ = static_cast(a.ptr());
+ params_.b_ = static_cast(b.ptr());
+ params_.scales_ = static_cast(scales.ptr());
+ params_.zero_points_ = nullptr;
+ params_.m_ = m;
+ params_.n_ = n;
+ params_.k_ = k;
+ }
+
+ MatrixFloatInt4(DeviceArray& output,
+ DeviceArray& a,
+ DeviceArray& b,
+ DeviceArray& scales,
+ DeviceArray& zeropoints,
+ int m, int n, int k) : MatrixFloatInt4(output, a, b, scales, m, n, k) {
+ params_.zero_points_ = static_cast(zeropoints.ptr());
+ }
+
+ void Run() override {
+ contrib::cuda::TryMatMul4Bits(
+ params_.output_,
+ params_.a_,
+ params_.b_,
+ params_.scales_,
+ params_.zero_points_,
+ params_.m_,
+ params_.n_,
+ params_.k_,
+ 32,
+ params_.StreamHandle());
+ }
+
+ private:
+ // A VectorAddOp is a callable that can process const VectorAddParams*
+ using ParamsT = MatrixFloatInt4Params;
+ ParamsT params_{};
+};
+
+#define REGISTER_OP(name, type) \
+ py::class_>(m, #name "_" #type) \
+ .def(py::init()) \
+ .def(py::init()) \
+ .def("SetRepeats", &name::SetRepeats) \
+ .def("Profile", &name::Profile) \
+ .def("Run", &name::Run);
+
+KE_REGISTER(m) {
+ REGISTER_OP(MatrixFloatInt4, half);
+ REGISTER_OP(MatrixFloatInt4, float);
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py
new file mode 100644
index 0000000000000..7088039f9e531
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py
@@ -0,0 +1,78 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import sys
+from dataclasses import dataclass
+
+import kernel_explorer as ke
+import numpy as np
+from utils import dtype_to_bytes
+
+
+def dtype_to_funcs(dtype):
+ type_map = {
+ "float16": list(filter(lambda x: "DequantizeInt4_half" in x, dir(ke))),
+ "float32": list(filter(lambda x: "DequantizeInt4_float" in x, dir(ke))),
+ }
+ return type_map[dtype]
+
+
+dtypes = ["float16", "float32"]
+
+
+@dataclass
+class DequantizeInt4Metric(ke.BandwidthMetric):
+ n: int
+ k: int
+
+ def report(self):
+ return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} n={self.n} k={self.k} {self.name}"
+
+
+def profile_dequantize_int4_func(n, k, dtype, func):
+ np.random.seed(0)
+ output = np.random.rand(n, k).astype(dtype)
+ quant = np.random.randint(low=0, high=127, size=(n, (k + 31) // 32, 16)).astype("uint8")
+ scales = np.random.rand(n, (k + 31) // 32).astype(dtype)
+
+ output_d = ke.DeviceArray(output)
+ quant_d = ke.DeviceArray(quant)
+ scales_d = ke.DeviceArray(scales)
+ f = getattr(ke, func)
+ my_op = f(output_d, quant_d, scales_d, n, k)
+ duration_ms = my_op.Profile()
+ total_bytes = (n * k) / 2 + (n * k + n * k / 32) * dtype_to_bytes(dtype)
+
+ ke.report(DequantizeInt4Metric(func, dtype, duration_ms, total_bytes, n, k))
+
+
+def profile_with_args(n, k, dtype, sort):
+ with ke.benchmark(sort):
+ for func in dtype_to_funcs(dtype):
+ profile_dequantize_int4_func(n, k, dtype, func)
+
+
+def profile():
+ for dt in dtypes:
+ for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
+ profile_with_args(n, k, dt, True)
+ print()
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ group = parser.add_argument_group("profile with args")
+ group.add_argument("n", type=int)
+ group.add_argument("k", type=int)
+ group.add_argument("dtype", choices=dtypes)
+ group.add_argument("--sort", action="store_true")
+
+ if len(sys.argv) == 1:
+ profile()
+ else:
+ args = parser.parse_args()
+ profile_with_args(args.n, args.k, args.dtype, args.sort)
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py
index e378f3e1cc198..8182cdb17567c 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py
@@ -179,6 +179,7 @@ def profile_with_args(dtype, transa, transb, m, n, k, sort):
profile_gemm_func(getattr(ke, "RocBlasGemm" + dtype_suffix), dtype, transa, transb, m, n, k)
profile_gemm_func(getattr(ke, "CKGemm" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k)
profile_gemm_func(getattr(ke, "GemmTunable" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k)
+ profile_gemm_func(getattr(ke, "GemmBenchmark" + dtype_suffix), dtype, transa, transb, m, n, k)
if ke.is_hipblaslt_available():
profile_gemm_func(
getattr(ke, "GemmHipBlasLt" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py
new file mode 100644
index 0000000000000..9cb937a13ff27
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py
@@ -0,0 +1,132 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import sys
+from dataclasses import dataclass
+
+import kernel_explorer as ke
+import numpy as np
+from utils import dtype_to_bytes
+
+
+def dtype_to_funcs(dtype):
+ type_map = {
+ "float16": list(filter(lambda x: "MatrixFloatInt4_half" in x, dir(ke))),
+ "float32": list(filter(lambda x: "MatrixFloatInt4_float" in x, dir(ke))),
+ }
+ return type_map[dtype]
+
+
+def dtype_to_funcs_cublas(dtype):
+ type_map = {
+ "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))),
+ "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))),
+ }
+ return type_map[dtype]
+
+
+dtypes = ["float16", "float32"]
+
+
+@dataclass
+class MatrixMulMetric(ke.BandwidthMetric):
+ m: int
+ n: int
+ k: int
+
+ def report(self):
+ return (
+ f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}"
+ )
+
+
+@dataclass
+class MatrixFpInt4Metric(MatrixMulMetric):
+ is_symmetric: bool
+
+ def report(self):
+ return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} is_symmetric={self.is_symmetric} {self.name}"
+
+
+def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric):
+ np.random.seed(0)
+ output = np.random.rand(m, n).astype(dtype)
+ a = np.random.rand(m, k).astype(dtype)
+ b = np.random.randint(low=0, high=127, size=(n, (k + 31) // 32, 16)).astype("uint8")
+ scales = np.random.rand(n * ((k + 31) // 32)).astype(dtype)
+ zeropoints = np.random.rand((n * ((k + 31) // 32) + 1) // 2).astype(dtype)
+
+ output_d = ke.DeviceArray(output)
+ a_d = ke.DeviceArray(a)
+ b_d = ke.DeviceArray(b)
+ scales_d = ke.DeviceArray(scales)
+ zeropoints_d = ke.DeviceArray(zeropoints)
+ f = getattr(ke, func)
+
+ my_op = (
+ f(output_d, a_d, b_d, scales_d, m, n, k)
+ if is_symmetric
+ else f(output_d, a_d, b_d, scales_d, zeropoints_d, m, n, k)
+ )
+ duration_ms = my_op.Profile()
+ total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))
+
+ ke.report(MatrixFpInt4Metric(func, dtype, duration_ms, total_bytes, m, n, k, is_symmetric))
+
+
+def profile_gemm_func(m, n, k, dtype, func):
+ np.random.seed(0)
+ output = np.random.rand(m, n).astype(dtype)
+ a = np.random.rand(m, k).astype(dtype)
+ b = np.random.rand(k, n).astype(dtype)
+
+ output_d = ke.DeviceArray(output)
+ a_d = ke.DeviceArray(a)
+ b_d = ke.DeviceArray(b)
+ f = getattr(ke, func)
+ my_op = f(output_d, a_d, b_d, m, n, k)
+ duration_ms = my_op.Profile()
+ total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))
+
+ ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k))
+
+
+def profile_with_args(m, n, k, dtype, sort):
+ with ke.benchmark(sort):
+ for func in dtype_to_funcs(dtype):
+ profile_matmul_fp_int4_func(m, n, k, dtype, func, True)
+
+ for func in dtype_to_funcs(dtype):
+ profile_matmul_fp_int4_func(m, n, k, dtype, func, False)
+
+ for func in dtype_to_funcs_cublas(dtype):
+ profile_gemm_func(m, n, k, dtype, func)
+
+
+def profile():
+ dims_m = [1]
+ for dt in dtypes:
+ for m in dims_m:
+ for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
+ profile_with_args(m, n, k, dt, False)
+ print()
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ group = parser.add_argument_group("profile with args")
+ group.add_argument("m", type=int)
+ group.add_argument("n", type=int)
+ group.add_argument("k", type=int)
+ group.add_argument("dtype", choices=dtypes)
+ group.add_argument("--sort", action="store_true")
+
+ if len(sys.argv) == 1:
+ profile()
+ else:
+ args = parser.parse_args()
+ profile_with_args(args.m, args.n, args.k, args.dtype, args.sort)
diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
new file mode 100644
index 0000000000000..fea9e5e8cb739
--- /dev/null
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -0,0 +1,229 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import argparse
+import logging
+import os
+from typing import List, Tuple
+
+import numpy as np
+import numpy.typing as npt
+import onnx
+from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
+
+from onnxruntime.capi._pybind_state import quantize_matmul_4bits
+
+from .onnx_model import ONNXModel
+from .quant_utils import attribute_to_kwarg
+
+logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class MatMul4BitsQuantizer:
+ """Perform 4b quantization of constant MatMul weights"""
+
+ def __init__(self, model: ModelProto, block_size: int, is_symmetric: bool, nodes_to_exclude=None):
+ if nodes_to_exclude is None:
+ nodes_to_exclude = []
+ self.model = ONNXModel(model)
+ self.block_size = block_size
+ self.is_symmetric = is_symmetric
+ self.nodes_to_exclude = set(nodes_to_exclude)
+
+ @staticmethod
+ def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]:
+ for gid in range(len(graph_path) - 1, -1, -1):
+ graph = graph_path[gid]
+ for tensor in graph.initializer:
+ if tensor.name == name:
+ return tensor, graph
+ return None, None
+
+ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
+ """4b quantize fp32 weight to a blob"""
+
+ if len(fp32weight.shape) != 2:
+ raise ValueError("Current int4 block quantization only supports 2D tensors!")
+ rows, cols = fp32weight.shape
+
+ block_size = self.block_size
+ blob_size = block_size // 2
+ k_blocks = (rows + block_size - 1) // block_size
+ padded_rows = k_blocks * block_size
+ pad_len = padded_rows - rows
+ if pad_len > 0:
+ fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant")
+
+ # block wise quantization, each block comes from a single column
+ packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
+ scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
+ zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
+ quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric)
+
+ return (packed, scales, zero_point)
+
+ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto:
+ """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
+
+ if node.op_type != "MatMul":
+ return node # only care about MatMul for now
+
+ logger.info(f"start to quantize {node.name} ...")
+ if node.name in self.nodes_to_exclude:
+ logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
+ return node
+
+ inputB = node.input[1] # noqa: N806
+ B, Bs_graph = MatMul4BitsQuantizer.__get_initializer(inputB, graph_stack) # noqa: N806
+ if B is None:
+ logger.info("MatMul doesn't have const weight. Skip to quantize")
+ return node # only care about constant weight
+
+ B_array = onnx.numpy_helper.to_array(B) # noqa: N806
+ if len(B_array.shape) != 2:
+ logger.info("MatMul weight is not 2D. Skip to quantize")
+ return node # can only process 2-D matrix
+
+ packed, scales, zero_points = self.int4_block_quant(B_array)
+ B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806
+ B_quant.name = B.name + "_Q4"
+ for input in Bs_graph.input:
+ if input.name == inputB:
+ Bs_graph.input.remove(input)
+ break
+
+ scales_tensor = onnx.numpy_helper.from_array(scales)
+ scales_tensor.name = B.name + "_scales"
+ Bs_graph.initializer.extend([B_quant, scales_tensor])
+
+ input_names = [node.input[0], B_quant.name, scales_tensor.name]
+ if not self.is_symmetric:
+ zp_tensor = onnx.numpy_helper.from_array(zero_points)
+ zp_tensor.name = B.name + "_zero_points"
+ Bs_graph.initializer.extend([zp_tensor])
+ input_names.append(zp_tensor.name)
+
+ kwargs = {}
+ rows, cols = B_array.shape
+ kwargs["K"] = rows
+ kwargs["N"] = cols
+ kwargs["bits"] = 4
+ kwargs["block_size"] = self.block_size
+
+ matmul_q4_node = onnx.helper.make_node(
+ "MatMulNBits",
+ inputs=input_names,
+ outputs=[node.output[0]],
+ name=node.name + "_Q4" if node.name else "",
+ domain="com.microsoft",
+ **kwargs,
+ )
+
+ logger.info(f"complete quantization of {node.name} ...")
+
+ return matmul_q4_node
+
+ def _process_subgraph(self, graph_stack: List[GraphProto]):
+ new_nodes = []
+ graph = graph_stack[-1]
+
+ for node in graph.node:
+ graph_attrs = [
+ attr
+ for attr in node.attribute
+ if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
+ ]
+ if len(graph_attrs):
+ kwargs = {}
+ for attr in node.attribute:
+ if attr.type == onnx.AttributeProto.GRAPH:
+ # recursive call to take care of sub-graph
+ graph_stack.append(attr.g)
+ kv = {attr.name: self._process_subgraph(graph_stack)}
+ elif attr.type == onnx.AttributeProto.GRAPHS:
+ value = []
+ for subgraph in attr.graphs:
+ # recursive call to take care of sub-graph
+ graph_stack.append(subgraph)
+ value.extend([self._process_subgraph(graph_stack)])
+ kv = {attr.name: value}
+ else:
+ kv = attribute_to_kwarg(attr)
+ kwargs.update(kv)
+ node = onnx.helper.make_node( # noqa: PLW2901
+ node.op_type, node.input, node.output, name=node.name, **kwargs
+ )
+
+ new_nodes.append(self._q4_matmul_node_weight(node, graph_stack))
+
+ graph.ClearField("node")
+ graph.node.extend(new_nodes)
+ graph_stack.pop()
+ return graph
+
+ def process(self):
+ # use a stack to keep track of sub-graphs
+ graph_stack = [self.model.graph()]
+ opset_import = self.model.opset_import()
+
+ has_ms_domain = False
+ for opset in opset_import:
+ if opset.domain == "com.microsoft":
+ has_ms_domain = True
+ if not has_ms_domain:
+ opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
+
+ self._process_subgraph(graph_stack)
+ self.model.clean_initializers()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="""Blockwise int4 quantization for MatMul 2D weight matrices.
+
+A weight matrix is partitioned into into blocks, where each block is a
+continguous subset inside each column. Each block is quantized into a
+set of 4b integers with a scaling factor and an optional offset.
+"""
+ )
+
+ parser.add_argument("--input_model", required=True, help="Path to the input model file")
+ parser.add_argument("--output_model", required=True, help="Path to the output model file")
+ parser.add_argument("--block_size", required=False, default=32)
+ parser.add_argument(
+ "--symmetric", required=False, default=True, help="Indicate whether to quantize the model symmetrically"
+ )
+ parser.add_argument("-v", "--verbose", required=False, action="store_true")
+ parser.set_defaults(verbose=False)
+ parser.add_argument(
+ "--nodes_to_exclude",
+ nargs="+",
+ type=str,
+ required=False,
+ default=[],
+ help="Specify the nodes to be excluded from quantization with node names",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ if args.verbose:
+ logger.setLevel(logging.DEBUG)
+
+ input_model_path = args.input_model
+ output_model_path = args.output_model
+
+ if os.path.exists(output_model_path):
+ logger.error(f"file {output_model_path} already exists")
+ raise Exception(f"file {output_model_path} already exists")
+
+ model = onnx.load(input_model_path)
+ quant = MatMul4BitsQuantizer(model, args.block_size, args.symmetric, nodes_to_exclude=args.nodes_to_exclude)
+ quant.process()
+ quant.model.save_model_to_file(output_model_path, True)
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
new file mode 100644
index 0000000000000..dc8efbbaf3709
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -0,0 +1,163 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef ORT_MINIMAL_BUILD
+
+#include "core/common/span_utils.h"
+#include "core/framework/tensor.h"
+#include "core/mlas/inc/mlas_q4.h"
+#include "core/mlas/inc/mlas.h"
+#include "core/session/inference_session.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/framework/test_utils.h"
+#include "test/optimizer/graph_transform_test_builder.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
+#include "core/util/qmath.h"
+#include "contrib_ops/cpu/quantization/dequantize_blockwise.h"
+
+#include
+#include
+
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+
+namespace onnxruntime {
+namespace test {
+
+void QuantizeDequantize(std::vector