diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h deleted file mode 100644 index 11b5447d65ed2..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -namespace onnxruntime { -namespace contrib { - -#if defined(_MSC_VER) -#define FORCEINLINE __forceinline -#else -#define FORCEINLINE __attribute__((always_inline)) inline -#endif - -template -struct alignas(1) BlockwiseQuantBlock { - static_assert(block_size % 8 == 0); - - uint8_t blob_data[block_size / 8 * bits]; - - FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const; - FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const; - - FORCEINLINE void quant(const T* src, T& scale, int32_t k_idx, int32_t K, int32_t N); - FORCEINLINE void quant(const T* src, T& scale, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N); -}; - -template -struct alignas(1) BlockwiseQuantBlock { - static_assert(block_size % 8 == 0); - - uint8_t blob_data[block_size / 2]; - - FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const { - for (int i = 0; i < block_size; i += 2) { - T zp_t = static_cast(float(zp)); - if (k_idx + i < K) { - T x0 = static_cast(float(blob_data[i / 2] & 0xF)); - dst[i] = scale * (x0 - zp_t); - } - if (k_idx + i + 1 < K) { - T x1 = static_cast(float(blob_data[i / 2] >> 4)); - dst[i + 1] = scale * (x1 - zp_t); - } - } - } - - FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const { - constexpr uint8_t zp = 8; - dequant(dst, scale, zp, k_idx, K); - } - - FORCEINLINE void quant(const T* src, T& scale_block, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N) { - float min = static_cast(*src); - float max = static_cast(*src); - 
int32_t klen = std::min(block_size, K - k_idx); - for (int32_t kk = 0; kk < klen; kk++) { - const float v = static_cast(src[N * kk]); - if (v < min) min = v; - if (v > max) max = v; - } - min = std::min(min, 0.0f); - max = std::max(max, 0.0f); - - const float scale = (max - min) / ((1 << 4) - 1); - scale_block = static_cast(scale); - - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - float zero_point_fp = min; - if (scale != 0.0f) { - zero_point_fp = 0.f - min / scale; - } - - // Handle any clamping - if (zero_point_fp < 0.0f) { - zp = 0; - } else if (zero_point_fp > 15.0f) { - zp = 15; - } else { - zp = (uint8_t)roundf(zero_point_fp); - } - - for (int32_t kk = 0; kk < klen; kk += 2) { - const float v0 = static_cast(src[N * kk]); - const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp))); - - const float v1 = static_cast((kk + 1 < klen) ? src[N * (kk + 1)] : 0.f); - const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp))); - - blob_data[kk / 2] = vi0 | (vi1 << 4); - } - } - - FORCEINLINE void quant(const T* src, T& scale_block, int32_t k_idx, int32_t K, int32_t N) { - float amax = 0.0f; // abs(max) - float max = 0.0f; - - int32_t klen = std::min(block_size, K - k_idx); - - for (int32_t kk = 0; kk < klen; kk++) { - const float v = static_cast(src[N * kk]); - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; - } - } - - const float scale = max / (-8.f); - scale_block = static_cast(scale); - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - - for (int32_t kk = 0; kk < klen; kk += 2) { - const float v0 = src[N * kk] * reciprocal_scale; - const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 + 8.f))); - - const float v1 = (kk + 1 < klen) ? 
src[N * (kk + 1)] * reciprocal_scale : 0; - const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 + 8.f))); - - blob_data[kk / 2] = vi0 | (vi1 << 4); - } - } -}; - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h deleted file mode 100644 index 8811e5649fc19..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "blockwise_quant_block.h" - -#include - -#include "core/common/safeint.h" -#include "core/framework/float16.h" -#include "core/platform/threadpool.h" -#include - -namespace onnxruntime { -namespace contrib { - -template -void QuantizeBlockwise( - uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ] - const T* src, // shape: [K, N] - T* scale, // shape: [N * block_per_K] - uint8_t* zero_points, // shape: [N * block_per_K] if bits > 4 else [(N *block_per_K + 1) / 2] - int32_t N, - int32_t K, - onnxruntime::concurrency::ThreadPool* thread_pool) { - BlockwiseQuantBlock* dst_blob = - reinterpret_cast*>(dst); - - int32_t block_per_K = (K + block_size - 1) / block_size; - int32_t total_block_count = N * block_per_K; - - std::vector zero_points_tmp; // to avoid race condition - (void)zero_points_tmp; - uint8_t* zero_points_tmp_ptr = zero_points; - if (bits <= 4 && zero_points != nullptr) { - zero_points_tmp.resize(total_block_count, 0); - zero_points_tmp_ptr = zero_points_tmp.data(); - } - - concurrency::ThreadPool::TryBatchParallelFor( - thread_pool, - total_block_count, - [&](ptrdiff_t block_idx) { - int32_t n = static_cast(block_idx / block_per_K); - int32_t k_block_idx = static_cast(block_idx % block_per_K); - int32_t k = k_block_idx * block_size; - BlockwiseQuantBlock* blob_ptr = dst_blob + block_idx; - size_t 
offset = SafeInt(k) * N + n; - if (nullptr != zero_points_tmp_ptr) { - blob_ptr->quant(src + offset, scale[block_idx], zero_points_tmp_ptr[block_idx], k, K, N); - } else { - blob_ptr->quant(src + offset, scale[block_idx], k, K, N); - } - }, - 0); - - if (bits <= 4 && zero_points != nullptr) { // compact zero points - for (int32_t zp_idx = 0; zp_idx < total_block_count / 2; zp_idx++) { - zero_points[zp_idx] = ((zero_points_tmp[zp_idx * 2]) | (zero_points_tmp[zp_idx * 2 + 1] << 4)); - } - if (total_block_count & 1) { - zero_points[total_block_count / 2] = (zero_points[total_block_count / 2] & 0xf0) | zero_points_tmp[total_block_count - 1]; - } - } -} - -#define QuantizeBlockwise4Bits(block_size) \ - QuantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool); - -template -void QuantizeBlockwise( - uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ] - const T* src, // shape: [K, N] - T* scale, // shape: [N, block_per_K] - uint8_t* zero_points, // shape: [N, block_per_K] - int32_t block_size, - int32_t bits, - int32_t N, - int32_t K, - onnxruntime::concurrency::ThreadPool* thread_pool) { - ORT_ENFORCE(bits == 4, "only 4 bits is supported now"); - - if (16 == block_size) { - QuantizeBlockwise4Bits(16); - } else if (32 == block_size) { - QuantizeBlockwise4Bits(32); - } else if (64 == block_size) { - QuantizeBlockwise4Bits(64); - } else if (128 == block_size) { - QuantizeBlockwise4Bits(128); - } else if (256 == block_size) { - QuantizeBlockwise4Bits(256); - } else { - ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); - } -} - -#undef QuantizeBlockwise4Bits - -template -void DequantizeBlockwise( - T* dst, // shape: [N, K] - const uint8_t* src, // shape: [N, block_per_K, block_blob_size] - const T* scale, // shape: [N, block_per_K] - const uint8_t* zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2] - int32_t N, - int32_t K, - onnxruntime::concurrency::ThreadPool* thread_pool) { - int32_t block_per_K = 
(K + block_size - 1) / block_size; - int32_t task_count = N * block_per_K; - - const BlockwiseQuantBlock* src_blob = - reinterpret_cast*>(src); - - concurrency::ThreadPool::TryBatchParallelFor( - thread_pool, - task_count, - [&](ptrdiff_t task_idx) { - int32_t n = static_cast(task_idx / block_per_K); - int32_t k_block_idx = static_cast(task_idx % block_per_K); - int32_t k = k_block_idx * block_size; - const BlockwiseQuantBlock* blob_ptr = src_blob + task_idx; - size_t offset = SafeInt(n) * K + k; - if (nullptr != zero_points) { - if constexpr (bits > 4) { // zero point is stored with a byte - blob_ptr->dequant(dst + offset, scale[task_idx], zero_points[task_idx], k, K); - } else { // zero points is stored with 4bits - uint8_t zp = zero_points[task_idx / 2]; - zp = (task_idx & 1) ? (zp >> 4) : (zp & 0xf); - blob_ptr->dequant(dst + offset, scale[task_idx], zp, k, K); - } - } else { - blob_ptr->dequant(dst + offset, scale[task_idx], k, K); - } - }, - 0); -} - -#define DequantizeBlockwise4Bits(block_size) \ - DequantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool); - -template -void DequantizeBlockwise( - T* dst, // [N, K] - const uint8_t* src, // [N, block_per_K, block_blob_size] - const T* scale, // [N, block_per_K] - const uint8_t* zero_points, // [N, block_per_K] - int32_t block_size, - int32_t bits, - int32_t N, - int32_t K, - onnxruntime::concurrency::ThreadPool* thread_pool) { - ORT_ENFORCE(bits == 4, "only 4 bits is supported now"); - - if (16 == block_size) { - DequantizeBlockwise4Bits(16); - } else if (32 == block_size) { - DequantizeBlockwise4Bits(32); - } else if (64 == block_size) { - DequantizeBlockwise4Bits(64); - } else if (128 == block_size) { - DequantizeBlockwise4Bits(128); - } else if (256 == block_size) { - DequantizeBlockwise4Bits(256); - } else { - ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); - } -} - -#undef DequantizeBlockwise4Bits - -} // namespace contrib -} // namespace onnxruntime diff --git 
a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 57aada94be39c..c72d811170a27 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -5,8 +5,7 @@ #include "core/framework/op_kernel.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" -#include "dequantize_blockwise.h" -#include "core/mlas/inc/mlas.h" +#include "core/mlas/inc/mlas_q4.h" namespace onnxruntime { namespace contrib { @@ -18,6 +17,9 @@ class MatMulNBits final : public OpKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op," + " additional bits support is planned."); } Status Compute(OpKernelContext* context) const override; @@ -27,6 +29,7 @@ class MatMulNBits final : public OpKernel { int64_t N_; int64_t block_size_; int64_t nbits_; + bool column_wise_quant_{true}; }; Status MatMulNBits::Compute(OpKernelContext* ctx) const { @@ -46,15 +49,18 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { auto status = ctx->GetTempSpaceAllocator(&allocator); ORT_RETURN_IF_ERROR(status); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - DequantizeBlockwise(tmp_b_data_ptr.get(), - b_data, - scales_data, - zero_points_data, - static_cast(block_size_), - static_cast(nbits_), - static_cast(N_), - static_cast(K_), - thread_pool); + + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + zero_points_data, // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization 
or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); #if 0 // for debug auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 8c328d00b44d0..7921315ab52e1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -18,6 +18,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { + __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); @@ -61,15 +62,19 @@ __global__ void Dequantize4BitsKernel( const T* scale_data, const uint8_t* zero_points, int block_size, + int blocks_per_K, int blocks_per_threadblock, int shift) { int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); + int n_idx = block_id / blocks_per_K; + int kb_idx = block_id % blocks_per_K; int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); uint8_t zp = 8; if (zero_points) { - zp = (block_id & 0x01) ? (zero_points[block_id / 2] >> 4) : (zero_points[block_id / 2] & 0x0f); + zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; + zp = (kb_idx & 0x01) ? 
(zp >> 4) : (zp & 0x0f); } output = output + element_offset; @@ -100,6 +105,7 @@ Status Dequantize4Bits( scales_data, zero_points, block_size, + blocks_per_K, blocks_per_threadblock, shift); @@ -126,6 +132,244 @@ template Status Dequantize4Bits( int block_size, cudaStream_t stream); + +/////////////////////////////////////////////////////////////////////////////// +// A more general block-wise dequantization implementation that supports +// different block sizes and block orientations (row-wise/column-wise). + +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + +/** + * @brief Blockwise quantization constants + * @tparam ElementT source data type, e.g. fp32/fp16 + * @tparam block_size number of elements quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlkQuantTraits { + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 
4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + using ThreadBlk = Shape2D; +}; + +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +__global__ +void dequantizeThread(ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int rows, + int columns, + int thrd_row_blks) { + using QuantBlk = typename BlkQuantTraits::QuantBlk; + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + // !! 4b specific code + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + + const auto block_idx = blockIdx.x * blockDim.x + threadIdx.x; + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + const auto meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + // quantized matrix is stored in column major, packed by column + const auto q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; + + int32_t r_blk_idx = static_cast(block_idx % thrd_row_blks); + int32_t c_blk_idx = static_cast(block_idx / thrd_row_blks); + + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; + + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + // for 4b quant, kPackSize = 2, so we have 2 scales and 2 offsets + const ElementT scale_buf[2] = { + scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow], + ((r/QuantBlk::kRow) < (meta_rows - 1)) + ? scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow + 1] + : static_cast(0.0f)}; + const uint8_t zp_pair = (zero_points == nullptr) + ? 
0x88 + : zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2]; + const uint16_t zp_buf[2] = {(uint16_t)(zp_pair & 0x0f), (uint16_t)((zp_pair >> 4) & 0x0f)}; + const ElementT adjust_buf[2] = {(-scale_buf[0]) * static_cast(zp_buf[0]), + (-scale_buf[1]) * static_cast(zp_buf[1])}; + + for (int32_t j = c; j < c_end; ++j) { + const uint8_t* q_ptr = weights + j * q_rows; + for (int32_t i = r; i < (r_end - 1); i += 2) { + const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(i - r) / QuantBlk::kRow]; + + const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];; + const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow]; + + const auto vi = q_ptr[i / 2]; + + if constexpr (std::is_same::value) { + half2 scale_half2 = {scale0, scale1}; + half2 zp_adjust2 = {adjust0, adjust1}; + + half2 v = {__ushort2half_rn(vi & 0xF), __ushort2half_rn((vi >> 4) & 0xF)}; + half2 results = v * scale_half2 + zp_adjust2; + + dst[j * rows + i] = results.x; + dst[j * rows + (i + 1)] = results.y; + } else { + static_assert(std::is_same::value, "Only float and half are supported!"); + const uint8_t vi0 = vi & 0xf; + const uint8_t vi1 = vi >> 4; + dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0;; + dst[j * rows + (i + 1)] = static_cast(vi1) * scale1 + adjust1; + } + } + + if ((r_end & 1) && (r_end > r)) { + const auto scale0 = scale_buf[(r_end - 1 - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(r_end - 1 - r) / QuantBlk::kRow]; + + const auto vi = q_ptr[(r_end - 1) / 2]; + const uint8_t vi0 = vi & 0xf; + + dst[j * rows + (r_end - 1)] = static_cast(vi0) * scale0 + adjust0; + } + } +} + +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* scales, + const uint8_t* zero_points, int32_t rows, int32_t columns, + cudaStream_t stream) { + using QuantBlk = typename BlkQuantTraits::QuantBlk; + 
using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto grids = (total_thrd_blks + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; + dequantizeThread<<>>( + dst, + weights, + scales, + zero_points, + rows, + columns, + thrd_row_blks); +} + + +template +Status +DequantizeBlockwise4b( + T* dst, + const uint8_t* src, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream) { + switch (block_size) { + case 16: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 32: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 64: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 128: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, + columns, stream); + } else { + dequantize(dst, src, scales, zero_points, + rows, columns, stream); + } + return Status::OK(); + case 256: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, + columns, stream); + } else { + dequantize(dst, src, scales, zero_points, + rows, columns, stream); + } + return Status::OK(); + default: + // Only block size 16, 32, 64, 128, 256 are supported. 
+ return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::FAIL, + "Unsupported block size for blockwise quantization."); + } +} + +template +Status DequantizeBlockwise4b( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +template +Status DequantizeBlockwise4b( + half* dst, + const uint8_t* src, + const half* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index 741ce1e735b42..f9c09c55fd893 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -18,6 +18,33 @@ Status Dequantize4Bits( int block_size, cudaStream_t stream); + +/** + * @brief Dequantize a block-wise quantized matrix, and store the result in a + * column major matrix for use in subsequent GEMM. This implementation supports + * columnwise and rowwise block orientation. 
+ * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] + * @param[in] qelements pointer to the quantized elements, column major: [columns, rows] + * @param[in] scales pointer to the scales of quantized blocks, column major layout + * @param[in] zero_points pointer to the zero points of quantized blocks, packed column major + * scales + * @param[in] block_size size of the quantized block + * @param[in] columnwise whether the quantized matrix is columnwise or rowwise quantized + * @param[in] rows + * @param[in] columns + */ +template +Status DequantizeBlockwise4b( + T* dst, + const uint8_t* qelements, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 14a8163fef500..5b0e61e197014 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -27,6 +27,9 @@ class MatMulNBits final : public CudaKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op," + " additional bits support is planned."); } Status ComputeInternal(OpKernelContext* context) const override; @@ -36,6 +39,7 @@ class MatMulNBits final : public CudaKernel { int64_t N_; int64_t block_size_; int64_t nbits_; + bool column_wise_quant_blk_{true}; }; template @@ -50,8 +54,6 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? 
nullptr : zero_points->Data(); - ORT_ENFORCE(nbits_ == 4, "only 4 bits is supported now"); - typedef typename ToCudaType::MappedType CudaT; constexpr bool transa = false; @@ -81,14 +83,32 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); auto* b_data = b_data_ptr.get(); - ORT_RETURN_IF_ERROR(Dequantize4Bits(reinterpret_cast(b_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(K_padded), - SafeInt(N_), - SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + if (column_wise_quant_blk_) { + // column-wise block + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } #if 0 cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); T* b_data_cpu = new T[K_ * N_]; diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 4c3c345076416..f2600a506285d 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -96,6 +96,9 @@ __global__ void MatMulFloatInt4Kernel( constexpr int k_per_iter = 256; int k_iter = k / k_per_iter; + // blocks_per_k is the number of scales and zero points on the k dim + const int b_zp_k = (blocks_per_K + 1)/ 2; + extern __shared__ char 
shared_buffer[]; // load scale to shared buffer @@ -105,30 +108,39 @@ __global__ void MatMulFloatInt4Kernel( for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { b_scale_vec[i] = scales_data[offset + i]; } - for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K / 2; i += kColsPerThreadBlock * kWarpSize) { - b_zp_vec[i] = zero_points != nullptr ? zero_points[offset / 2 + i] : uint8_t(0x88); + + int zp_offset = n_block_id * kColsPerThreadBlock * b_zp_k; + for (int i = thread_id; i < kColsPerThreadBlock * b_zp_k; i += kColsPerThreadBlock * kWarpSize) { + b_zp_vec[i] = zero_points != nullptr ? zero_points[zp_offset + i] : uint8_t(0x88); } __syncthreads(); a_data += m_id * k; b_data_quant += n_id * blocks_per_K * (block_size / 2); + const int scale_col_offset = warp_id * blocks_per_K; + const int zp_col_offset = warp_id * b_zp_k; + float sum = 0.f; int k_id = 0; for (; k_id < (k & 0xffffff00); k_id += k_per_iter) { - uint32_t value = *(reinterpret_cast(b_data_quant + (k_id >> 1) + lane_id * 4)); - int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size; - T scale = b_scale_vec[block_idx]; - uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f); + const int t_k = k_id + (lane_id << 3); // k index for this thread + const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point + uint32_t value = *(reinterpret_cast(b_data_quant + (t_k >> 1))); + T scale = b_scale_vec[scale_col_offset + t_meta_k]; + uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; + zp = (t_meta_k & 0x01) ? 
(zp >> 4) : (zp & 0x0f); sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); } // handle reminder if (k_id + lane_id * 8 < k) { + const int t_k = k_id + (lane_id << 3); // k index for this thread + const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point uint32_t value = *(reinterpret_cast(b_data_quant + k_iter * 128 + lane_id * 4)); - int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size; - T scale = b_scale_vec[block_idx]; - uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f); + T scale = b_scale_vec[scale_col_offset + t_meta_k]; + uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; + zp = (t_meta_k & 0x01) ? (zp >> 4) : (zp & 0x0f); sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); } diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index f3bc2a2434ab3..7c7b729117e4a 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -39,7 +39,7 @@ typedef enum { * @brief Computes the number of bytes required to pack and int4-quantize * a weight matrix * @param QType type of block quantization - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @return size of the packing buffer, 0 if the operation is not yet supported. */ @@ -53,11 +53,11 @@ MlasQ4GemmPackBSize( /** * @brief Prepack and Quantize fp32 weight tensor to int4 blocks - * + * * @param QType type of block quantization * @param PackedBuf destination buffer * @param FpData the pointer to fp32 matrix - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. 
* @param ldb leading dimension of B */ @@ -257,14 +257,14 @@ MlasBlockwiseQuantMetaShape( * matrix shape [rows, columns], compute the shape of the * quantized matrix [q_rows, q_cols]. The quantized matrix * is in column major layout, with bits packed on the column. - * - * @tparam T - * @param block_size - * @param columnwise - * @param rows - * @param columns - * @param q_rows - * @param q_cols + * + * @tparam T + * @param block_size + * @param columnwise + * @param rows + * @param columns + * @param q_rows + * @param q_cols */ template void @@ -283,21 +283,22 @@ MlasBlockwiseQuantizedShape( * parameters (scales, zero points) are packed into separate matrices * all in column major layout for faster access during subsequent matrix * multiplication. - * + * * @tparam ElementT type of the input matrix element, usually floating point - * + * @tparam qbits number of bits used for quantization, 4 for int4 + * * @param dst points to the quantized matrix, shape [rows, columns] column major - * @param scales points to the scales matrix, column major + * @param scales points to the scales matrix, column major * @param zero_points points to the zero_points matrix, column major * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row - * @param rows - * @param columns - * @param leading_dimension - * @param thread_pool + * @param rows + * @param columns + * @param leading_dimension + * @param thread_pool */ -template +template void MlasQuantizeBlockwise( uint8_t* dst, @@ -318,19 +319,21 @@ MlasQuantizeBlockwise( * parameters (scales, zero points) are from separate matrices packed * in column major layout. 
Output is a floating point matrix in column * major layout for faster access during subsequent matrix multiplication. - * + * * @tparam ElementT type of the dequantized matrix element, usually floating point + * @tparam qbits number of bits used for quantization, 4 for int4 + * * @param dst points to dequantized matrix shape [rows, columns] column major * @param src points to quantized matrix, column major * @param scales points to quantization scales, column major * @param zero_points points to quantization zero points, column major * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row - * @param rows - * @param columns - * @param thread_pool + * @param rows + * @param columns + * @param thread_pool */ -template +template void MlasDequantizeBlockwise( ElementT* dst, diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 24a2212ba0714..fbd1030de8ab7 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -364,7 +364,7 @@ range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) } else { zp = (uint8_t)roundf(zero_point_fp); } - scale = static_cast(scale_f); + scale = ScaleT(scale_f); } template @@ -377,7 +377,7 @@ range2scale(float min, float max, ScaleT& scale) max = fabsf(max) > fabsf(min) ? 
max : min; - scale = static_cast(max / mid_fp); + scale = ScaleT(max / mid_fp); }; @@ -773,7 +773,7 @@ MlasBlockwiseQuantizedShape( ); -template +template void MlasQuantizeBlockwise( uint8_t* dst, @@ -791,50 +791,50 @@ MlasQuantizeBlockwise( switch (block_size) { case 16: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 32: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 64: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 128: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 256: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - 
BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; @@ -847,7 +847,7 @@ MlasQuantizeBlockwise( template void -MlasQuantizeBlockwise( +MlasQuantizeBlockwise( uint8_t* dst, float* scales, uint8_t* zero_points, @@ -860,8 +860,23 @@ MlasQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + MLAS_FP16* scales, + uint8_t* zero_points, + const MLAS_FP16* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + -template +template void MlasDequantizeBlockwise( T* dst, @@ -878,46 +893,46 @@ MlasDequantizeBlockwise( switch (block_size) { case 16: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; case 32: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; case 64: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; case 128: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + 
BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; case 256: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; @@ -929,7 +944,7 @@ MlasDequantizeBlockwise( template void -MlasDequantizeBlockwise( +MlasDequantizeBlockwise( float* dst, const uint8_t* src, const float* scales, diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 04dfa9b51e112..ff76887e917cd 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -5,7 +5,7 @@ #include #include -#include "contrib_ops/cpu/quantization/dequantize_blockwise.h" +#include "core/mlas/inc/mlas_q4.h" #include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" #include "core/util/thread_utils.h" @@ -53,15 +53,16 @@ void QuantizeMatMul4BitsBlockwise( py::buffer_info scale_buf = scale.request(); py::buffer_info zp_buf = zero_points.request(); - contrib::QuantizeBlockwise( - static_cast(dst_buf.ptr), - static_cast(src_buf.ptr), - static_cast(scale_buf.ptr), - is_symmetric ? nullptr : static_cast(zp_buf.ptr), + MlasQuantizeBlockwise( + reinterpret_cast(dst_buf.ptr), + reinterpret_cast(scale_buf.ptr), + is_symmetric ? 
nullptr : reinterpret_cast(zp_buf.ptr), + reinterpret_cast(src_buf.ptr), block_size, - 4, - N, + true, K, + N, + N, tp.get()); } diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py index 9cb937a13ff27..111e156cd6d01 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py @@ -56,7 +56,7 @@ def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric): a = np.random.rand(m, k).astype(dtype) b = np.random.randint(low=0, high=127, size=(n, (k + 31) // 32, 16)).astype("uint8") scales = np.random.rand(n * ((k + 31) // 32)).astype(dtype) - zeropoints = np.random.rand((n * ((k + 31) // 32) + 1) // 2).astype(dtype) + zeropoints = np.random.rand(n * (((k + 31) // 32 + 1) // 2)).astype(dtype) output_d = ke.DeviceArray(output) a_d = ke.DeviceArray(a) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index fea9e5e8cb739..1c3c212b54fa4 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -61,7 +61,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: # block wise quantization, each block comes from a single column packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) - zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") + zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric) return (packed, scales, zero_point) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index dc8efbbaf3709..918ee0e6eb976 100644 --- 
a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -14,7 +14,6 @@ #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" #include "core/util/qmath.h" -#include "contrib_ops/cpu/quantization/dequantize_blockwise.h" #include #include @@ -25,6 +24,8 @@ namespace onnxruntime { namespace test { +static constexpr int QBits = 4; + void QuantizeDequantize(std::vector& raw_vals, std::vector& quant_vals, std::vector& scales, @@ -35,27 +36,29 @@ void QuantizeDequantize(std::vector& raw_vals, OrtThreadPoolParams to; auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); - contrib::QuantizeBlockwise( + + MlasQuantizeBlockwise( quant_vals.data(), - raw_vals.data(), scales.data(), zp != nullptr ? zp->data() : nullptr, + raw_vals.data(), block_size, - 4, - N, + true, K, + N, + N, tp.get()); // Note that input1_f_vals is NxK after dequant - contrib::DequantizeBlockwise( - raw_vals.data(), - quant_vals.data(), - scales.data(), - zp != nullptr ? zp->data() : nullptr, - block_size, - 4, - N, - K, + MlasDequantizeBlockwise( + raw_vals.data(), // dequantized output + quant_vals.data(), // quantized input + scales.data(), // quantization scales + zp != nullptr ? 
zp->data() : nullptr, // quantization zero points + block_size, // quantization block size + true, // columnwise quantization + K, // number of rows + N, // number of columns tp.get()); } @@ -69,13 +72,21 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N); #endif - int64_t block_per_k = (K + block_size - 1) / block_size; - int64_t number_of_block = block_per_k * N; - int64_t block_blob_size = block_size * 4 / 8; - int64_t buf_size = number_of_block * (block_size * 4 / 8); - std::vector input1_vals(buf_size); - std::vector scales(number_of_block); - std::vector zp((N * block_per_k + 1) / 2); + int meta_rows; + int meta_cols; + MlasBlockwiseQuantMetaShape((int)block_size, true, (int)K, (int)N, meta_rows, meta_cols); + + int q_rows; + int q_cols; + MlasBlockwiseQuantizedShape((int)block_size, true, (int)K, (int)N, q_rows, q_cols); + + std::vector input1_vals(q_rows * q_cols); + std::vector scales(meta_rows * meta_cols); + + // TODO!! 
THIS SHOULD BE PROVIDED BY MLAS + // sub 8b packing always happen on the column dimension + const int packed_meta_rows = (meta_rows * QBits + 7) / 8; + std::vector zp(packed_meta_rows * meta_cols); QuantizeDequantize(input1_f_vals, input1_vals, @@ -100,13 +111,13 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop test.AddAttribute("K", K); test.AddAttribute("N", N); test.AddAttribute("block_size", block_size); - test.AddAttribute("bits", 4); + test.AddAttribute("bits", QBits); if (use_float16) { test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); - test.AddInput("B", {N, block_per_k, block_blob_size}, input1_vals, true); - test.AddInput("scales", {N * block_per_k}, ToFloat16(scales), true); + test.AddInput("B", {q_cols, q_rows}, input1_vals, true); + test.AddInput("scales", {meta_cols * meta_rows}, ToFloat16(scales), true); if (has_zeropoint) { - test.AddInput("zero_points", {(N * block_per_k + 1) / 2}, zp, true); + test.AddInput("zero_points", {meta_cols * packed_meta_rows}, zp, true); } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); @@ -117,10 +128,10 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else { test.AddInput("A", {M, K}, input0_vals, false); - test.AddInput("B", {N, block_per_k, block_blob_size}, input1_vals, true); - test.AddInput("scales", {N * block_per_k}, scales, true); + test.AddInput("B", {q_cols, q_rows}, input1_vals, true); + test.AddInput("scales", {meta_cols * meta_rows}, scales, true); if (has_zeropoint) { - test.AddInput("zero_points", {(N * block_per_k + 1) / 2}, zp, true); + test.AddInput("zero_points", {meta_cols * packed_meta_rows}, zp, true); } test.AddOutput("Y", {M, N}, expected_vals); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index 1683cee856907..f836da8277bb8 100644 --- 
a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -96,7 +96,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase { } } - MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, columnwise, rows, columns, threadpool_ptr); + MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, + columnwise, rows, columns, threadpool_ptr); MlasTranspose(dequant_buf, transposed, columns, rows); @@ -104,7 +105,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase { float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); uint8_t* o_zp = symmetric ? nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); - MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, columnwise, rows, columns, columns, threadpool_ptr); + MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, + columnwise, rows, columns, columns, threadpool_ptr); for (int c = 0; c < columns; c++) { for (int r = 0; r < rows; r += 2) { diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py index e03a0167d070a..765825d4b86e3 100644 --- a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py @@ -38,8 +38,8 @@ def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, i matrix_float_padded = np.pad(matrix_float, ((0, pad_len), (0, 0)), "constant") packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") - scales = np.zeros((cols * k_blocks), dtype=matrix_float_padded.dtype) - zero_point = np.full((cols * k_blocks + 1) // 2, 136, dtype="uint8") + scales = np.zeros((cols, k_blocks), dtype=matrix_float_padded.dtype) + zero_point = np.full((cols, (k_blocks + 1) // 2), 136, dtype="uint8") matrix_float_padded = np.transpose(matrix_float_padded) for n in range(cols): @@ -61,10 
+61,12 @@ def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, i zp = min(15, max(0, round(zero_point_fp))) reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 - block_idx = n * k_blocks + k_id // block_size - scales[block_idx] = scale - zp_pair = zero_point[block_idx // 2] - zero_point[block_idx // 2] = ((zp_pair & 0x0F) | (zp << 4)) if (block_idx & 1) else ((zp_pair & 0xF0) | zp) + block_idx = k_id // block_size + scales[n, block_idx] = scale + zp_pair = zero_point[n, block_idx // 2] + zero_point[n, block_idx // 2] = ( + ((zp_pair & 0x0F) | (zp << 4)) if (block_idx & 1) else ((zp_pair & 0xF0) | zp) + ) blk_int0 = np.clip( np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 2] * reciprocal_scale + zp)), @@ -76,7 +78,7 @@ def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, i 0, 15, ).astype("uint8") - packed[n, k_id // block_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) + packed[n, block_idx] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) return (packed, scales, zero_point) @@ -88,8 +90,8 @@ def quantize_blockwise_4bits_target(matrix_float: npt.ArrayLike, block_size: int k_blocks = (rows + block_size - 1) // block_size packed = np.zeros((cols, k_blocks, block_size // 2), dtype="uint8") - scales = np.zeros((cols * k_blocks), dtype=matrix_float.dtype) - zero_point = np.full((cols * k_blocks + 1) // 2, 136, dtype="uint8") + scales = np.zeros((cols, k_blocks), dtype=matrix_float.dtype) + zero_point = np.full((cols, (k_blocks + 1) // 2), 136, dtype="uint8") from onnxruntime.capi._pybind_state import quantize_matmul_4bits quantize_matmul_4bits(packed, matrix_float, scales, zero_point, block_size, cols, rows, is_symmetric) @@ -116,24 +118,22 @@ def test_quantize_blockwise_4bits(self): assert np.allclose(zero_point_ref, zero_point) for c in range(quant_value_ref.shape[0]): for k in range(quant_value_ref.shape[1]): - block_idx = c * quant_value_ref.shape[1] + k - zp_idx = block_idx 
// 2 assert np.allclose( dequantize_blockwise_4bits( - quant_value_ref[c][k], - scales_ref[block_idx], - (zero_point_ref[zp_idx] >> 4) - if (block_idx & 1) - else (zero_point_ref[zp_idx] & 0x0F), + quant_value_ref[c, k], + scales_ref[c, k], + (zero_point_ref[c, k // 2] >> 4) + if (k & 1) + else (zero_point_ref[c, k // 2] & 0x0F), min(block_size, rows - k * block_size), ), dequantize_blockwise_4bits( - quant_value[c][k], - scales[block_idx], - (zero_point[zp_idx] >> 4) if (block_idx & 1) else (zero_point[zp_idx] & 0x0F), + quant_value[c, k], + scales[c, k], + (zero_point[c, k // 2] >> 4) if (k & 1) else (zero_point[c, k // 2] & 0x0F), min(block_size, rows - k * block_size), ), - atol=1.2 * abs(scales[block_idx]), + atol=1.2 * abs(scales[c, k]), )