Skip to content

Commit

Permalink
support matmulnbits with hqq
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Feb 7, 2024
1 parent 36d2236 commit 2e58ea2
Show file tree
Hide file tree
Showing 4 changed files with 373 additions and 80 deletions.
31 changes: 4 additions & 27 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
// pre-packed and block-compacted into int4
//

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "matmul_nbits.h"

Check warning on line 10 in onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc:10: Include the directory when naming header files [build/include_subdir] [4]
#include "core/common/status.h"
#include "core/framework/float16.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "matmul_nbits.cuh"
#include "dequantize_blockwise.cuh"
Expand All @@ -19,29 +19,6 @@ namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;

template <typename T>
class MatMulNBits final : public CudaKernel {
public:
MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) {
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op,"
" additional bits support is planned.");
}

Status ComputeInternal(OpKernelContext* context) const override;

private:
int64_t K_;
int64_t N_;
int64_t block_size_;
int64_t nbits_;
bool column_wise_quant_blk_{true};
};

template <typename T>
Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input<Tensor>(0);
Expand Down Expand Up @@ -162,7 +139,7 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<MLFloat16>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<int32_t>()}),
MatMulNBits<MLFloat16>);

} // namespace cuda
Expand Down
41 changes: 41 additions & 0 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

//
// This module define MatMulFp32Q4 operator, it is basically
// matmul float32 with right hand side being a 2-D matrix
// pre-packed and block-compacted into int4
//
#pragma once
#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;

Check warning on line 17 in onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h:17: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

template <typename T>
class MatMulNBits final : public CudaKernel {
public:
MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) {
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
}

Status ComputeInternal(OpKernelContext* context) const override;

private:
int64_t K_;
int64_t N_;
int64_t block_size_;
int64_t nbits_;
bool column_wise_quant_blk_{true};
};

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
22 changes: 8 additions & 14 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3339,22 +3339,14 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's scale and zero point are specified by input scales and zero_points.
Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = block_size / 8 * bits
For a block blob. It is stored in format:
struct Blob {
uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization
uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization
uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
}
Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = block_size / 8 * bits
Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
- [(N * n_blocks_per_col + 1) / 2] if bits <=4
- [N * n_blocks_per_col] if bits > 4
)DOC";

ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits)
Expand All @@ -3373,12 +3365,14 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
"type T1.",
AttributeProto::INT, static_cast<int64_t>(0))
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1-dimensional data blob", "T2")
.Input(1, "B", "1 or 2 dimensional data blob", "T2")
.Input(2, "scales", "quantization scale", "T1")
.Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional)
.Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional)
.Input(4, "g_idx", "group_idx for gptq", "T2", OpSchema::Optional)
.Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
.TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
.TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
.TypeConstraint("T2", {"tensor(uint8)", "tensor(uint32)", "tensor(int32)"}, "Constrain quantized weight types to uint8/uint32/int32/float16.")

Check warning on line 3374 in onnxruntime/core/graph/contrib_ops/contrib_defs.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/graph/contrib_ops/contrib_defs.cc:3374: Lines should be <= 120 characters long [whitespace/line_length] [2]
.TypeConstraint("T3", {"tensor(uint8)", "tensor(uint32)", "tensor(int32)", "tensor(float16)"}, "Constrain quantized zero point types to uint8/uint32/int32/float16.")

Check warning on line 3375 in onnxruntime/core/graph/contrib_ops/contrib_defs.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/graph/contrib_ops/contrib_defs.cc:3375: Lines should be <= 120 characters long [whitespace/line_length] [2]
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
Expand Down
Loading

0 comments on commit 2e58ea2

Please sign in to comment.