Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Feb 28, 2024
1 parent 430bc4f commit 6a3caa6
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 22 deletions.
17 changes: 13 additions & 4 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2810,12 +2810,21 @@ This version of the operator has been available since version 1 of the 'com.micr

Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = block_size / 8 * bits
- blob_size = CeilDiv(block_size * bits, 8), where 8 is the number of bits in a uint8_t
For all bits from 2-8, a row of data is stored compactly and represented by uint8_t.
- for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t.
4bit example:
|.|.|.|.| .|.|.|.| =uint8_t (2x4bit)
- for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted.
3bit example:
|.|.|. |.|.|. |.|.|. = 9bit, which spans 2 uint8_t; the highest bit of the second uint8_t is used.
The last uint8_t may have some bits unused.


Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one uint8_t. If bits > 4, one zero point is stored with one uint8_t. Thus, its shape is:
- [(N * n_blocks_per_col + 1) / 2] if bits <=4
- [N * n_blocks_per_col] if bits > 4
Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B.
- [CeilDiv((N * n_blocks_per_col + 1) * bits, 8)]
If zero_points has same type as A, it's not packed and has the same shape as Scales.

#### Version

Expand Down
40 changes: 33 additions & 7 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

#include <cstdint>
#include <type_traits>

#include "core/common/common.h"
#include "core/common/narrow.h"
#include "core/common/safeint.h"
Expand All @@ -12,7 +15,6 @@
#include "core/mlas/inc/mlas_q4.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

#ifdef ORT_NEURAL_SPEED
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
Expand Down Expand Up @@ -54,6 +56,17 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level
}
} // namespace

// Extracts the tensor element type of `node_arg` into `type`.
// Returns true on success; on failure `type` is left as
// TensorProto_DataType_UNDEFINED and false is returned.
bool GetType(const NodeArg& node_arg, int32_t& type) {
  type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
  const auto* type_proto = node_arg.TypeAsProto();
  // A usable element type requires the full chain: proto -> tensor_type -> elem_type.
  if (type_proto != nullptr && type_proto->has_tensor_type() && type_proto->tensor_type().has_elem_type()) {
    type = type_proto->tensor_type().elem_type();
    return true;
  }
  return false;
}

class MatMulNBits final : public OpKernel {
public:
MatMulNBits(const OpKernelInfo& info)
Expand All @@ -63,6 +76,17 @@ class MatMulNBits final : public OpKernel {
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
const auto& node = info.node();
auto input_defs = node.InputDefs();
// g_idx
if (input_defs.size() > 4) {
act_order_ = true;
}
int32_t type;
if (input_defs.size() > 3 && GetType(*input_defs[3], type)) {
zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8;
}

ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
#ifdef ORT_NEURAL_SPEED
Expand Down Expand Up @@ -92,6 +116,8 @@ class MatMulNBits final : public OpKernel {
const size_t N_;
const size_t block_size_;
const size_t nbits_;
bool act_order_{false};
bool zero_point_is_not_quant_{false};
const int64_t accuracy_level_;
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
Expand All @@ -109,7 +135,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;

if (act_order_ || zero_point_is_not_quant_) {
return Status::OK();
}
#if defined(ORT_NEURAL_SPEED)

if (!all_constant_) {
Expand Down Expand Up @@ -216,7 +244,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();

Expand Down Expand Up @@ -261,8 +288,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
#endif // defined(ORT_NEURAL_SPEED)

const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const Tensor* reorder_idx = ctx->Input<Tensor>(4);
const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input<Tensor>(3) : nullptr;
const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input<Tensor>(4) : nullptr;

const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
Expand All @@ -289,8 +316,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const size_t lda = helper.Lda(false);

const bool has_single_b_matrix =
(reorder_idx_data == nullptr) &&
(!zero_points || !zero_points->IsDataType<float>()) &&
(!act_order_) && (!zero_point_is_not_quant_) &&
std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });

Check warning on line 320 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:320: Lines should be <= 120 characters long [whitespace/line_length] [2]

if (has_single_b_matrix) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <type_traits>

#include "core/common/common.h"
#include "core/framework/float16.h"
#include "core/providers/common.h"
Expand Down Expand Up @@ -69,7 +71,7 @@ void DequantizeBlockwise(
const uint8_t* quant_data, // quantized input
const inputT* scales_data, // quantization scales
const zeroT* zero_points, // quantization zero points
const int32_t* reorder_idx, // quantization zero points
const int32_t* reorder_idx, // reorder_idx for groupwise quantization
int32_t block_size, // quantization block size
bool, // columnwise quantization or row-wise
int32_t K, // number of rows in quantized input
Expand Down
6 changes: 0 additions & 6 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

//
// This module define MatMulNBits operator, it is basically
// matmul float with right hand side being a 2-D matrix
// pre-packed and block-compacted into int4
//

#include "contrib_ops/cuda/quantization/matmul_nbits.h"

Expand Down Expand Up @@ -144,7 +139,6 @@ delete[] b_data_cpu;
UseTF32()));
}


return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

//
// This module define MatMulNBits operator, it is basically
// matmul float32 with right hand side being a 2-D matrix
// matmul float with right hand side being a 2-D matrix
// pre-packed and block-compacted into int4
//
#pragma once
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3350,7 +3350,7 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7
- for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t.
4bit example:
|.|.|.|.| .|.|.|.| =uint8_t (2x4bit)
- for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t seperately. no bits are wasted.
- for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted.
3bit example:
|.|.|. |.|.|. |.|.|. = 9bit, which spans 2 uint8_t; the highest bit of the second uint8_t is used.
The last uint8_t may have some bits unused.
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
}

void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,
bool has_zeropoint, bool use_float16, bool has_g_idx = false, bool zp_is_4bit = true, float fp16_abs_error = 0.02f) {
std::cerr << M << " " << N << " " << K << " " << block_size << " " << has_zeropoint << " " << use_float16 << " " << has_g_idx << " " << zp_is_4bit << " " << std::endl;
bool has_zeropoint, bool use_float16, bool has_g_idx = false,
bool zp_is_4bit = true, float fp16_abs_error = 0.02f) {
zp_is_4bit = zp_is_4bit | has_g_idx;
RandomValueGenerator random{1234};
std::vector<float> input0_vals(random.Gaussian<float>(std::vector<int64_t>({M, K}), 0.0f, 0.25f));
Expand Down

0 comments on commit 6a3caa6

Please sign in to comment.