From 6a3caa6a51fde87fde52e999c11fb015447f9af3 Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Tue, 27 Feb 2024 11:21:39 +0800
Subject: [PATCH] format

---
 docs/ContribOperators.md                      | 17 ++++++--
 .../cpu/quantization/matmul_nbits.cc          | 40 +++++++++++++++----
 .../cpu/quantization/matmul_nbits_impl.cc     |  4 +-
 .../cuda/quantization/matmul_nbits.cc         |  6 ---
 .../cuda/quantization/matmul_nbits.h          |  2 +-
 .../core/graph/contrib_ops/contrib_defs.cc    |  2 +-
 .../test/contrib_ops/matmul_4bits_test.cc     |  4 +-
 7 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 7536923323060..feb579d91c191 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2810,12 +2810,21 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 
   Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
   - n_blocks_per_col = (K + block_size - 1) / block_size
-  - blob_size = block_size / 8 * bits
+  - blob_size = CeilDiv(block_size * bits, 8), where 8 is the number of bits in one uint8_t
+    For all bit widths from 2 to 8, a row of data is stored compactly as uint8_t, with no padding between elements.
+      - for 2, 4 and 8 bits, 4x2bit, 2x4bit or 1x8bit values are stored in one uint8_t.
+            4bit example:
+            |.|.|.|.| .|.|.|.| =uint8_t (2x4bit)
+      - for 3, 5, 6 and 7 bits, 32x3bit, 32x5bit, 16x6bit and 32x7bit values are stored in 12, 20, 12 and 28 uint8_t respectively; no bits are wasted.
+            3bit example:
+            |.|.|. |.|.|. |.|.|. = 9 bits, which span 2 uint8_t; the highest bit of the second uint8_t is used.
+            The last uint8_t may have some bits unused.
+
   Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
 
-  Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
-  - [(N * n_blocks_per_col + 1) / 2] if bits <=4
-  - [N * n_blocks_per_col] if bits > 4
+  Input zero_points is stored as uint8_t or in the same type as A. When stored as uint8_t it uses the same packing method as input B, so its shape is:
+  - [CeilDiv((N * n_blocks_per_col + 1) * bits, 8)]
+  If zero_points has the same type as A, it is not packed and has the same shape as Scales.
 
 #### Version
 
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 621ba5bd936e9..602dd98d8c0d6 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -1,8 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"
+
 #include
 #include
+
 #include "core/common/common.h"
 #include "core/common/narrow.h"
 #include "core/common/safeint.h"
@@ -12,7 +15,6 @@
 #include "core/mlas/inc/mlas_q4.h"
 #include "core/providers/cpu/math/matmul_helper.h"
 #include "core/providers/common.h"
-#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"
 
 #ifdef ORT_NEURAL_SPEED
 #include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
@@ -54,6 +56,17 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level
   }
 }  // namespace
 
+bool GetType(const NodeArg& node_arg, int32_t& type) {
+  type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
+  const auto* type_proto = node_arg.TypeAsProto();
+  if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) {
+    return false;
+  }
+
+  type = type_proto->tensor_type().elem_type();
+  return true;
+}
+
 class MatMulNBits final : public OpKernel {
  public:
   MatMulNBits(const OpKernelInfo& info)
@@ -63,6 +76,17 @@ class MatMulNBits final : public OpKernel {
         block_size_{narrow(info.GetAttr("block_size"))},
         nbits_{narrow(info.GetAttr("bits"))},
         accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} {
+    const auto& node = info.node();
+    auto input_defs = node.InputDefs();
+    // g_idx
+    if (input_defs.size() > 4) {
+      act_order_ = true;
+    }
+    int32_t type;
+    if (input_defs.size() > 3 && GetType(*input_defs[3], type)) {
+      zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8;
+    }
+
     ORT_ENFORCE(nbits_ == 4,
                 "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
 #ifdef ORT_NEURAL_SPEED
@@ -92,6 +116,8 @@ class MatMulNBits final : public OpKernel {
   const size_t N_;
   const size_t block_size_;
   const size_t nbits_;
+  bool act_order_{false};
+  bool zero_point_is_not_quant_{false};
   const int64_t accuracy_level_;
   const bool column_wise_quant_{true};
   IAllocatorUniquePtr packed_b_;
@@ -109,7 +135,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
                             /*out*/ bool& is_packed,
                             /*out*/ PrePackedWeights* prepacked_weights) {
   is_packed = false;
-
+  if (act_order_ || zero_point_is_not_quant_) {
+    return Status::OK();
+  }
 #if defined(ORT_NEURAL_SPEED)
 
   if (!all_constant_) {
@@ -216,7 +244,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep
 
 Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
-
   const Tensor* a = ctx->Input(0);
   const auto* a_data = a->Data();
 
@@ -261,8 +288,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
 #endif  // defined(ORT_NEURAL_SPEED)
 
   const Tensor* scales = ctx->Input(2);
-  const Tensor* zero_points = ctx->Input(3);
-  const Tensor* reorder_idx = ctx->Input(4);
+  const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input(3) : nullptr;
+  const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input(4) : nullptr;
   const auto* scales_data = scales->Data();
   const auto* zero_points_data = zero_points == nullptr ?
                                        nullptr : zero_points->DataRaw();
@@ -289,8 +316,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   const size_t lda = helper.Lda(false);
 
   const bool has_single_b_matrix =
-      (reorder_idx_data == nullptr) &&
-      (!zero_points || !zero_points->IsDataType()) &&
+      (!act_order_) && (!zero_point_is_not_quant_) &&
       std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });
 
   if (has_single_b_matrix) {
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
index d86b54d397341..f92e59e990ba5 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
@@ -1,11 +1,13 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 #include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"
+
 #include
 #include
 #include
 #include
 #include
+
 #include "core/common/common.h"
 #include "core/framework/float16.h"
 #include "core/providers/common.h"
@@ -69,7 +71,7 @@ void DequantizeBlockwise(
     const uint8_t* quant_data,   // quantized input
     const inputT* scales_data,   // quantization scales
     const zeroT* zero_points,    // quantization zero points
-    const int32_t* reorder_idx,  // quantization zero points
+    const int32_t* reorder_idx,  // reorder_idx for groupwise quantization
     int32_t block_size,          // quantization block size
     bool,                        // columnwise quantization or row-wise
     int32_t K,                   // number of rows in quantized input
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
index f0bb6d7459850..bc5a3e3557f49 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -1,11 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-//
-// This module define MatMulNBits operator, it is basically
-// matmul float with right hand side being a 2-D matrix
-// pre-packed and block-compacted into int4
-//
 
 #include "contrib_ops/cuda/quantization/matmul_nbits.h"
 
@@ -144,7 +139,6 @@
       delete[] b_data_cpu;
 
       UseTF32()));
   }
-
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
index caf14bbf70d3e..f5c2c6c4e4fdf 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
@@ -3,7 +3,7 @@
 //
 // This module define MatMulNBits operator, it is basically
-// matmul float32 with right hand side being a 2-D matrix
+// matmul float with right hand side being a 2-D matrix
 // pre-packed and block-compacted into int4
 //
 #pragma once
 
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 4fca0ac625958..f06a3785f362d 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3350,7 +3350,7 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7
       - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t.
             4bit example:
             |.|.|.|.| .|.|.|.| =uint8_t (2x4bit)
-      - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t seperately. no bits are wasted.
+      - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted.
             3bit example:
             |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used.
             The last uint_8 may have some bits unused.
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index 1d465d2a583ea..d294fd4e2b0e0 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -67,8 +67,8 @@ void QuantizeDequantize(std::vector& raw_vals,
 }
 
 void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,
-             bool has_zeropoint, bool use_float16, bool has_g_idx = false, bool zp_is_4bit = true, float fp16_abs_error = 0.02f) {
-  std::cerr << M << " " << N << " " << K << " " << block_size << " " << has_zeropoint << " " << use_float16 << " " << has_g_idx << " " << zp_is_4bit << " " << std::endl;
+             bool has_zeropoint, bool use_float16, bool has_g_idx = false,
+             bool zp_is_4bit = true, float fp16_abs_error = 0.02f) {
   zp_is_4bit = zp_is_4bit | has_g_idx;
   RandomValueGenerator random{1234};
   std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f));
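
Note (not part of the patch): the ContribOperators.md hunk above defines the packed buffer sizes purely arithmetically: n_blocks_per_col = (K + block_size - 1) / block_size, blob_size = CeilDiv(block_size * bits, 8), and, for uint8_t zero points, CeilDiv((N * n_blocks_per_col + 1) * bits, 8). The standalone C++ sketch below just evaluates those formulas for one concrete shape and round-trips a block of 4-bit values; the low-nibble-first order used here is an illustrative assumption, not a statement about the kernel's actual bit layout.

#include <cstdint>
#include <cstdio>
#include <vector>

// CeilDiv as used in the operator documentation above.
static size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  // Example: a K x N weight, column-wise blockwise quantization, 4-bit elements.
  const size_t K = 128, N = 4, block_size = 32, bits = 4;

  const size_t n_blocks_per_col = (K + block_size - 1) / block_size;  // 4
  const size_t blob_size = CeilDiv(block_size * bits, 8);             // 16 bytes per block
  const size_t packed_b_bytes = N * n_blocks_per_col * blob_size;     // [N][n_blocks_per_col][blob_size] = 256
  const size_t zero_point_bytes = CeilDiv((N * n_blocks_per_col + 1) * bits, 8);  // per the documented formula

  std::printf("n_blocks_per_col=%zu blob_size=%zu packed_B=%zu zero_points=%zu\n",
              n_blocks_per_col, blob_size, packed_b_bytes, zero_point_bytes);

  // Pack one block of 4-bit values, two per uint8_t (low nibble first in this sketch).
  std::vector<uint8_t> quant(block_size);
  for (size_t i = 0; i < block_size; ++i) quant[i] = static_cast<uint8_t>(i % 16);

  std::vector<uint8_t> packed(blob_size, 0);
  for (size_t i = 0; i < block_size; ++i) {
    packed[i / 2] |= static_cast<uint8_t>((quant[i] & 0xF) << ((i % 2) * 4));
  }

  // Unpack and verify the round trip.
  for (size_t i = 0; i < block_size; ++i) {
    const uint8_t v = (packed[i / 2] >> ((i % 2) * 4)) & 0xF;
    if (v != quant[i]) return 1;
  }
  std::printf("4-bit pack/unpack round trip OK\n");
  return 0;
}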
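
For 3-, 5-, 6- and 7-bit widths the documentation says values are packed back to back across byte boundaries so that, for example, 32 x 3-bit values occupy exactly 12 uint8_t with no bits wasted. The generic bit-stream packer below is an independent illustration of that byte count, again assuming LSB-first order; it is not taken from the ONNX Runtime sources.

#include <cstdint>
#include <cstdio>
#include <vector>

// Append the low `bits` bits of each value to a little-endian bit stream,
// so sub-byte values may straddle uint8_t boundaries and nothing is padded.
static std::vector<uint8_t> PackBits(const std::vector<uint8_t>& values, unsigned bits) {
  std::vector<uint8_t> out((values.size() * bits + 7) / 8, 0);
  size_t bit_pos = 0;
  for (uint8_t v : values) {
    for (unsigned b = 0; b < bits; ++b, ++bit_pos) {
      if ((v >> b) & 1) {
        out[bit_pos / 8] |= static_cast<uint8_t>(1u << (bit_pos % 8));
      }
    }
  }
  return out;
}

int main() {
  std::vector<uint8_t> vals(32);
  for (size_t i = 0; i < vals.size(); ++i) vals[i] = static_cast<uint8_t>(i % 8);  // values fit in 3 bits

  const auto packed = PackBits(vals, 3);
  std::printf("32 x 3-bit -> %zu bytes\n", packed.size());  // prints 12: 96 bits exactly fill 12 bytes
  return 0;
}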