diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index f9d9b13f0fedc..de0f084b1cf11 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DequantizeLinearBlockWise); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4); @@ -243,81 +244,82 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, - // add more kernels here - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // add more kernels here + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_SPARSE_TENSORS) - BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif - BuildKernelCreateInfo, - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifndef ORT_MINIMAL_BUILD - BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as - // contrib ops to main backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // These ops were experimental ops in onnx domain which have been removed now. We add them here as + // contrib ops to main backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN - BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif #ifdef ENABLE_TRAINING_OPS - // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or - // 2). this is needed by inference for other purpose. - BuildKernelCreateInfo, + // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or + // 2). this is needed by inference for other purpose. + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantizeLinear_blockwise.cc b/onnxruntime/contrib_ops/cpu/quantization/dequantizeLinear_blockwise.cc new file mode 100644 index 0000000000000..e19aba8f8d134 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantizeLinear_blockwise.cc @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define DequantizeLinearBlockWise operator, it is basically +// dequantize input tensor and unpack it into float/half tensor. 
+// +#include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/op_kernel.h" + +#include "core/mlas/inc/mlas_q4.h" +#include "core/providers/common.h" +#include "dequantizeLinear_blockwise_imp.h" + +namespace onnxruntime { +namespace contrib { + +class DequantizeLinearBlockWise final : public OpKernel { + public: + DequantizeLinearBlockWise(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(K_ * N_ > 0, "K and N must be greater than 0."); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("axis", &axis_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("packing", &packing_)); + if (packing_ == "default") { + ORT_ENFORCE(axis_ == 1, "axis_ should be 1 for default packing."); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for DequantizeLinearBlockWise op," + " additional bits support is planned."); + } else if (packing_ == "gptq") { + ORT_ENFORCE(axis_ == 0, "axis_ should be 0 for gptq packing."); + ORT_ENFORCE(nbits_ == 4, "nbits_ should be 4."); + } else if (packing_ == "hqq") { + ORT_ENFORCE(axis_ == 0, "axis_ should be 0 for hqq packing."); + ORT_ENFORCE(nbits_ == 4, "nbits_ should be 4."); + } else { + ORT_THROW("Unsupported packing type: ", packing_); + } + } + + Status Compute(OpKernelContext* context) const override; + Status ComputeGPTQ(OpKernelContext* context) const; + Status ComputeHQQ(OpKernelContext* context) const; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t axis_; + int64_t nbits_; + std::string packing_; +}; + +Status DequantizeLinearBlockWise::ComputeGPTQ(OpKernelContext* ctx) const { + const int in_features = K_; + const int out_features = N_; + const int groupsize = block_size_; + const auto* input_qweight = ctx->Input(0); + const auto* input_scale = ctx->Input(1); + const auto* input_zeros = ctx->Input(2); + const auto* input_gidx = ctx->Input(3); + const auto& weight_shape = input_qweight->Shape(); + + auto OutputShape = TensorShape({in_features, out_features}); + + Tensor* Y = ctx->Output(0, OutputShape); + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + auto fp16_weight_shape = weight_shape; + fp16_weight_shape[0] = in_features; + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const auto* zero_point = input_zeros && input_zeros->DataRaw() ? 
input_zeros->DataRaw() : nullptr; + if (nbits_ != 4) { + GPTQPacking::GeneralDequant(thread_pool, input_qweight->Data(), + input_scale->Data(), + static_cast(zero_point), + input_gidx->Data(), + Y->MutableData(), + in_features, weight_shape[1], nbits_, groupsize); + } else if (input_gidx && input_gidx->Shape().Size() > 1) { + GPTQPacking::DequantWeightNbitGidx(thread_pool, input_qweight->Data(), + input_scale->Data(), + static_cast(zero_point), + input_gidx->Data(), + Y->MutableData(), + in_features, weight_shape[1], nbits_); + } else { + GPTQPacking::DequantWeightNbit(thread_pool, input_qweight->Data(), + input_scale->Data(), + static_cast(zero_point), + Y->MutableData(), + in_features, weight_shape[1], nbits_, groupsize); + } + return Status::OK(); +} + +Status DequantizeLinearBlockWise::ComputeHQQ(OpKernelContext* ctx) const { + const int in_features = K_; + // const int out_features = N_; + const int groupsize = block_size_; + const auto* input_qweight = ctx->Input(0); + const auto* input_scale = ctx->Input(1); + const auto* input_zeros = ctx->Input(2); + const auto& weight_shape = input_qweight->Shape(); + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + auto OutputShape = TensorShape({in_features, N_}); + + Tensor* Y = ctx->Output(0, OutputShape); + + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + auto fp16_weight_shape = weight_shape; + fp16_weight_shape[0] = in_features; + if (nbits_ != 4) { + GPTQPacking::GeneralDequant(thread_pool, input_qweight->Data(), + input_scale->Data(), + input_zeros->Data(), + nullptr, + Y->MutableData(), + in_features, weight_shape[1], nbits_, groupsize); + } else{ + GPTQPacking::DequantWeightNbit(thread_pool, input_qweight->Data(), + input_scale->Data(), + input_zeros->Data(), + Y->MutableData(), + in_features, weight_shape[1], nbits_, groupsize); + } + return Status::OK(); +} + +Status DequantizeLinearBlockWise::Compute(OpKernelContext* ctx) const { + if (packing_ == "gptq") { + return this->ComputeGPTQ(ctx); + } + if (packing_ == "hqq") { + return this->ComputeHQQ(ctx); + } + const Tensor* b = ctx->Input(0); + const Tensor* scales = ctx->Input(1); + const Tensor* zero_points = ctx->Input(2); + + const uint8_t* blob_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? 
nullptr : zero_points->Data(); + + TensorShape b_shape({N_, K_}); + + Tensor* Y = ctx->Output(0, b_shape); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + MlasDequantizeBlockwise( + Y->MutableData(), // dequantized output + blob_data, // quantized input + scales_data, // quantization scales + zero_points_data, // quantization zero points + static_cast(block_size_), // quantization block size + axis_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + DequantizeLinearBlockWise, + kMSDomain, + 1, + kCpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DequantizeLinearBlockWise); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantizeLinear_blockwise_imp.cc b/onnxruntime/contrib_ops/cpu/quantization/dequantizeLinear_blockwise_imp.cc new file mode 100644 index 0000000000000..49f00c42cdb0b --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantizeLinear_blockwise_imp.cc @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define DequantizeLinearBlockWise operator, it is basically +// dequantize input tensor and unpack it into float/half tensor. +// + +#include +#include +#include +#include +#include +#include +#include "core/common/common.h" +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { +namespace GPTQPacking { +static std::unique_ptr l_g_idx = std::make_unique(4096*128); + +template +void GeneralDequant(concurrency::ThreadPool* pool, const int32_t* qweight_i32, const float* scale, + const ZERO_TYPE* qzeros_i32, const int32_t* g_idx, + float* output, uint32_t mat_k, uint32_t mat_n, int bits, int group_size) { + const int32_t* group_idx = g_idx; + if (g_idx == nullptr) { + int32_t* sg_idx = l_g_idx.get(); + for (uint32_t i = 0; i < mat_k; i++) { + sg_idx[i] = i / group_size; + } + group_idx = sg_idx; + } + const uint32_t max_num_in_bits = (1 << bits) - 1; + const uint32_t qzeros_size_n = (mat_n * bits + 31) / 32; + constexpr int32_t kThreadBlockSize = 64; + concurrency::ThreadPool::TrySimpleParallelFor( + pool, static_cast(mat_k / kThreadBlockSize), + [&](std::ptrdiff_t row_block) { + row_block*=kThreadBlockSize; + //for (uint32_t row_block = 0; row_block < mat_k; row_block+=kThreadBlockSize) { + for (uint32_t sub_row = 0; sub_row < mat_k; sub_row += 1) { + uint32_t row = row_block + sub_row; + for (uint32_t col = 0; col < mat_n; col++) { + uint32_t reorder_idx = group_idx[row]; + + uint32_t weight_int = 0; + uint8_t wv1 = 0; + + int start_bits = row * bits; + int first = start_bits / 32; + int end_bits = (start_bits + bits); + int second = end_bits / 32; + start_bits = start_bits % 32; + end_bits = end_bits % 32; + + weight_int = qweight_i32[first * mat_n + col]; + wv1 = (weight_int >> start_bits) & max_num_in_bits; + if (first != second) { + weight_int = qweight_i32[second * mat_n + col]; + wv1 |= (weight_int & ((1 << end_bits) - 1)) << (32 - start_bits); + } + + float f_zeros 
= 8; + if constexpr (std::is_same_v) { + f_zeros = qzeros_i32[reorder_idx * mat_n + col]; + } else { + uint32_t zero_v1 = 0x88888888; + uint8_t zv1 = 8; + if (qzeros_i32 != nullptr) { + int start_bits = col * bits; + int first = start_bits / 32; + int end_bits = (start_bits + bits); + int second = end_bits / 32; + start_bits = start_bits % 32; + end_bits = end_bits % 32; + + zero_v1 = qzeros_i32[reorder_idx * qzeros_size_n + first]; + zv1 = (zero_v1 >> start_bits) & max_num_in_bits; + if (first != second) { + zero_v1 = qzeros_i32[reorder_idx * qzeros_size_n + second]; + zv1 |= (zero_v1 & ((1 << end_bits) - 1)) << (32 - start_bits); + } + f_zeros = float(zv1); + } + } + + // Write the dequantized value: (quantized - zero_point) * scale. + output[row * mat_n + col] = (float(wv1) - f_zeros) * scale[reorder_idx * mat_n + col]; + } + } + //} + }); 
+} + +template void +DequantWeightNbitGidx(concurrency::ThreadPool* pool, + const int32_t* qweight_i32, const float* scale, + const uint32_t* qzeros_i32, const int32_t* g_dix, + float* output, + uint32_t mat_k, uint32_t mat_n, int bits); + + +template +void DequantWeightNbit( + concurrency::ThreadPool* pool, + const int32_t* qweight_i32, + const SCALE_TYPE* scale, + const ZEROT* qzeros_i32, + SCALE_TYPE* output, + uint32_t mat_k, + uint32_t mat_n, + uint32_t bits, + uint32_t groupsize) { + assert(bits == 4); + constexpr uint32_t kWBITS = 4; + constexpr uint32_t kOUTCOLBLOCK = 64; + const uint32_t kCompressedSize = 32 / 4; + const int qzeros_size_n = (mat_n + kCompressedSize - 1) / kCompressedSize; + concurrency::ThreadPool::TrySimpleParallelFor( + pool, static_cast(mat_k/groupsize), + [&](std::ptrdiff_t row_block) { + row_block *= groupsize; + //for (uint32_t row_block = 0; row_block < mat_k; row_block += groupsize) { + for (uint32_t col_block = 0; col_block < mat_n; col_block += kOUTCOLBLOCK) { + for (uint32_t sub_col_block = 0; sub_col_block < kOUTCOLBLOCK; sub_col_block += kCompressedSize) { + for (uint32_t inner_group_idx = 0; inner_group_idx < groupsize; inner_group_idx += kCompressedSize) { + uint32_t qweight_g8[kCompressedSize]; + std::copy_n(qweight_i32 + (row_block + inner_group_idx) / kCompressedSize * mat_n + col_block + sub_col_block, kCompressedSize, qweight_g8); + uint32_t zeros_u32 = qzeros_i32[row_block / groupsize * qzeros_size_n + (col_block + sub_col_block) / kCompressedSize]; + uint8_t zeros_u8_8[kCompressedSize]; + for (uint32_t i = 0; i < kCompressedSize; i++) { + zeros_u8_8[i] = (zeros_u32 >> (i * kWBITS)) & 0xf; + } + for (uint32_t row_sub_idx = 0; row_sub_idx < kCompressedSize; row_sub_idx++) { + const SCALE_TYPE* scale_p = &scale[row_block / groupsize * mat_n + col_block + sub_col_block]; + for (uint32_t sub_col_idx = 0; sub_col_idx < kCompressedSize; sub_col_idx++) { + int col_idx = col_block + sub_col_block + sub_col_idx; + auto qweight = int32_t((qweight_g8[sub_col_idx] >> (row_sub_idx * kWBITS)) & 0xf); + output[(row_block + inner_group_idx + row_sub_idx) * mat_n + col_idx] = float(qweight - int32_t(zeros_u8_8[sub_col_idx])) * (*scale_p++); + } + } + } + } + } + //} + }); +} + + + +template void DequantWeightNbit( + concurrency::ThreadPool* pool, + const int32_t* qweight_i32, + const float* scale, + const float* qzeros_i32, + float* output, + uint32_t mat_k, + uint32_t mat_n, + uint32_t bits, + uint32_t groupsize); + +template void DequantWeightNbit( + concurrency::ThreadPool* pool, + const int32_t* qweight_i32, + const float* scale, + const uint32_t* qzeros_i32, + float* output, + uint32_t mat_k, + uint32_t mat_n, + uint32_t bits, + uint32_t groupsize); +} // namespace GPTQPacking +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantizeLinear_blockwise_imp.h b/onnxruntime/contrib_ops/cpu/quantization/dequantizeLinear_blockwise_imp.h new file mode 100644 index 0000000000000..ce45ff03c438c --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantizeLinear_blockwise_imp.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define DequantizeLinearBlockWise operator, it is basically +// dequantize input tensor and unpack it into float/half tensor. 
+// + +#pragma once + +#include "core/mlas/inc/mlas_q4.h" +namespace onnxruntime { +namespace contrib { +namespace GPTQPacking { + +template +void DequantWeightNbitGidx(concurrency::ThreadPool* pool, + const int32_t* qweight_i32_i, const SCALE_TYPE* scale_fp16, + const uint32_t* qzeros_i32_i, const int32_t* g_dix, + SCALE_TYPE* b_fp16, + uint32_t mat_k, uint32_t mat_n, int bits); + +template +void DequantWeightNbit( + concurrency::ThreadPool* pool, + const int32_t* qweight_i32, + const SCALE_TYPE* scale, + const ZEROT* qzeros_i32, + SCALE_TYPE* output, + uint32_t mat_k, + uint32_t mat_n, + uint32_t bits, + uint32_t groupsize); +template +void GeneralDequant(concurrency::ThreadPool* pool, const int32_t* qweight_i32, const float* scale, + const ZERO_TYPE* qzeros_i32, const int32_t* g_idx, + float* output, uint32_t mat_k, uint32_t mat_n, int bits, int group_size); + +} // namespace GPTQPacking +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 34b44694a5fcc..776b91992b6a0 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -122,6 +122,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DequantizeLinearBlockWise); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DequantizeLinearBlockWise); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4); @@ -323,6 +325,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantizeLinear_blockwise.cc b/onnxruntime/contrib_ops/cuda/quantization/dequantizeLinear_blockwise.cc new file mode 100644 index 0000000000000..c7127e6558aa6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantizeLinear_blockwise.cc @@ -0,0 +1,215 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define DequantizeLinearBlockWise operator, it is basically +// dequantize input tensor and unpack it into float/half tensor. 
+// +#include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/framework/tensor_shape.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +#include "dequantize_blockwise.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class DequantizeLinearBlockWise final : public CudaKernel { + public: + DequantizeLinearBlockWise(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(K_ * N_ > 0, "K and N must be greater than 0."); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("axis", &axis_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("packing", &packing_)); + if (packing_ == "default") { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for DequantizeLinearBlockWise op," + " additional bits support is planned."); + } else if (packing_ == "gptq") { + ORT_ENFORCE(axis_ == 0, "axis_ should be 0 for gptq packing."); + ORT_ENFORCE(nbits_ > 1 && nbits_ < 8, "nbits_ should be in range of 2-8."); + } else if (packing_ == "hqq") { + ORT_ENFORCE(axis_ == 0, "axis_ should be 0 for hqq packing."); + ORT_ENFORCE(nbits_ == 4, "nbits_ should be in range of 2-8."); + } else { + ORT_THROW("Unsupported packing type: ", packing_); + } + } + + Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternalGPTQ(OpKernelContext* context) const; + Status ComputeInternalHQQ(OpKernelContext* context) const; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t axis_; + int64_t nbits_; + std::string packing_; +}; + +template +Status DequantizeLinearBlockWise::ComputeInternalGPTQ(OpKernelContext* ctx) const { + const int in_features = K_; + const int out_features = N_; + const int groupsize = block_size_; + const auto* input_qweight = ctx->Input(0); + const auto* input_scale = ctx->Input(1); + const auto* input_zeros = ctx->Input(2); + const auto* input_gidx = ctx->Input(3); + const auto& weight_shape = input_qweight->Shape(); + + auto OutputShape = TensorShape({in_features, out_features}); + + Tensor* Y = ctx->Output(0, OutputShape); + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + auto fp16_weight_shape = weight_shape; + fp16_weight_shape[0] = in_features; + + const auto* zero_point = input_zeros && input_zeros->DataRaw() ? 
input_zeros->DataRaw() : nullptr; + if (input_gidx && input_gidx->Shape().Size() > 1) { + GPTQPacking::DequantWeightNbitGidx(Stream(ctx), input_qweight->Data(), + input_scale->Data(), + static_cast(zero_point), + input_gidx->Data(), + Y->MutableData(), + in_features, weight_shape[1], nbits_, groupsize); + } else { + GPTQPacking::DequantWeightNbit(Stream(ctx), input_qweight->Data(), + input_scale->Data(), + static_cast(zero_point), + Y->MutableData(), + in_features, weight_shape[1], nbits_, groupsize); + } + return Status::OK(); +} + +template +Status DequantizeLinearBlockWise::ComputeInternalHQQ(OpKernelContext* ctx) const { + const int in_features = K_; + // const int out_features = N_; + const int groupsize = block_size_; + const auto* input_qweight = ctx->Input(0); + const auto* input_scale = ctx->Input(1); + const auto* input_zeros = ctx->Input(2); + const auto& weight_shape = input_qweight->Shape(); + typedef typename ToCudaType::MappedType CudaT; + auto OutputShape = TensorShape({in_features, N_}); + + Tensor* Y = ctx->Output(0, OutputShape); + + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + auto fp16_weight_shape = weight_shape; + fp16_weight_shape[0] = in_features; + + GPTQPacking::DequantWeightNbit(Stream(ctx), input_qweight->Data(), + input_scale->Data(), + static_cast(input_zeros->DataRaw()), + Y->MutableData(), + in_features, weight_shape[1], nbits_, groupsize); + + return Status::OK(); +} + +template +Status DequantizeLinearBlockWise::ComputeInternal(OpKernelContext* ctx) const { + if (packing_ == "gptq") { + return this->ComputeInternalGPTQ(ctx); + } + if (packing_ == "hqq") { + return this->ComputeInternalHQQ(ctx); + } + const Tensor* b = ctx->Input(0); + const Tensor* scales = ctx->Input(1); + const Tensor* zero_points = ctx->Input(2); + + const uint8_t* blob_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? 
nullptr : zero_points->Data(); + + typedef typename ToCudaType::MappedType CudaT; + + TensorShape b_shape({N_, K_}); + + Tensor* Y = ctx->Output(0, b_shape); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + + if (axis_ == 1) { + // column-wise block + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(Y->MutableDataRaw()), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(Y->MutableDataRaw()), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(block_size_), + axis_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + +#if 0 + cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); + T* b_data_cpu = new T[K_ * N_]; + cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); + delete[] b_data_cpu; +#endif + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DequantizeLinearBlockWise, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinearBlockWise); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DequantizeLinearBlockWise, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + DequantizeLinearBlockWise); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 0b8d0cbd8616a..6ff296d5fe570 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -353,7 +353,7 @@ template bool TryMatMul4Bits( namespace GPTQPacking { constexpr int kBlockSize = 256; constexpr int kNumWaves = 32; -const int width_element_per_block = 32 * 2; +constexpr int kWidthPerBlock = 32 * 2; template __device__ __forceinline__ float warpReduceSum(float sum) { if (WarpSize >= 32) @@ -383,7 +383,7 @@ __global__ void MatMulW4A16Kernel(T* out, const T* inA, const uint32_t* inB, con const half2* inA_start = (const half2*)(inA + blockIdx.y * matrix_k + y_start); - int n_offset_x = bid * width_element_per_block + threadIdx.x * 2; + int n_offset_x = bid * kWidthPerBlock + threadIdx.x * 2; int start_group_id = (y_start / groupsize); int compressed_idx = threadIdx.x % 4; @@ -484,7 +484,7 @@ __global__ void MatMulW4A16Kernel(T* out, const T* inA, const uint32_t* inB, con __syncthreads(); sum[i] = warpReduceSum<32>(sum[i]); if (threadIdx.x == 0) { - out[+blockIdx.y * matrix_N + bid * width_element_per_block + + out[+blockIdx.y * matrix_N + bid * kWidthPerBlock + threadIdx.y * 2 + i] = __float2half_rn(sum[i]); } } @@ -592,7 +592,7 @@ void TryMatMul4Bits( uint32_t groupsize) { const int block_k = ((matrix_k + 31) / 32 + 7) / 8 * 8; - dim3 gridDim = {(matrix_N + width_element_per_block - 1) / width_element_per_block, matrix_M}; + dim3 gridDim = {(matrix_N + kWidthPerBlock - 1) / kWidthPerBlock, 
matrix_M}; dim3 blockDim = {32, (matrix_k + block_k - 1) / block_k}; MatMulW4A16Kernel<<>>( static_cast(mul_out_data), static_cast(input_data), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e8ff8e8f519a6..f9fc95adee73b 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3330,7 +3330,107 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t updateOutputElemType(ctx, 1, ONNX_NAMESPACE::TensorProto::UINT32); } }); + static const char* DequantizeLinearBlockWise_ver1_doc = R"DOC( +DequantizeLinearBlockWise de-quantizes an input tensor (uint8/int8/uint32/int32) produced by blockwise quantization with N bits (e.g., 2, 3, 4, 5, 6, 7) of packing into a normal (float32/half) tensor. +It has two to four inputs: the quantized input tensor, the scale tensor, an optional zero_point tensor, and an optional g_idx tensor (used by "gptq" packing). + 1. The input is a 2D tensor. Its original shape is specified by attributes 'K' and 'N'. + 2. Packing can be along dimension 0 (i.e., row) or dimension 1 (i.e., column), as specified by attribute 'axis'. + 3. The bit width and group size are specified by attributes 'bits' and 'block_size'. + 4. zero_point is optional. If it is not specified, the quantization is symmetric. zero_point can have either the same type as the input or the same type as scale, in which case it is not packed. +It supports different packing styles: ["default", "gptq", "hqq", "exl2", "awq"]. Please note that "exl2" and "awq" are not supported yet. +with "default", we have + Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = block_size / 8 * bits + + A block blob is stored in the following format: + struct Blob { + uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization + uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization + uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization + } + + Input scales is stored in the same type as the original type of B (float32, float16) with shape: [N * n_blocks_per_col] + Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored in one uint8_t. If bits > 4, one zero point is stored in one uint8_t. Thus, its shape is: + - [(N * n_blocks_per_col + 1) / 2] if bits <= 4 + - [N * n_blocks_per_col] if bits > 4 +
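For illustration, the buffer shapes implied by the "default" packing can be computed as below. This is a non-normative sketch that simply restates the formulas above; the function name is illustrative only:
```
def default_packing_shapes(K, N, bits, block_size):
    # Restates the "default" packing layout described above.
    n_blocks_per_col = (K + block_size - 1) // block_size
    blob_size = block_size // 8 * bits
    b_shape = (N, n_blocks_per_col, blob_size)          # uint8 blobs
    scales_shape = (N * n_blocks_per_col,)              # same dtype as the original weight
    if bits <= 4:
        zero_points_shape = ((N * n_blocks_per_col + 1) // 2,)  # two zero points per uint8
    else:
        zero_points_shape = (N * n_blocks_per_col,)             # one zero point per uint8
    return b_shape, scales_shape, zero_points_shape
```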
+For the two packing methods below, bits are squeezed into uint32_t, and each uint32_t contains 32/bits quantized values. +Basically, you can pack it with +``` +def pack_on_row_fast_anybit(pack_tensor, ori_int_tensor, bits): + need_transpose = False + if pack_tensor.shape[0] != ori_int_tensor.shape[0]: + need_transpose = True + ori_int_tensor = ori_int_tensor.T + pack_tensor.mul_(0) + wf = torch.arange(0, bits).to(pack_tensor.device).view(1, 1, -1) + out = torch.bitwise_right_shift(ori_int_tensor.unsqueeze(-1), wf) + torch.bitwise_and(out, 1, out=out) + out = out.reshape(ori_int_tensor.shape[0], -1, 32) + wf1 = torch.arange(0, 32, 1).to(pack_tensor.device).view(1, 1, -1) + out = torch.bitwise_left_shift(out, wf1) + out = out.sum(dim=-1).int() + + if need_transpose: + out = out.T.contiguous() + pack_tensor.copy_(out) +``` +with "gptq", we have: + 32 = bits_of(int32) + Input is stored as uint32_t with shape: [ceil(K*bits, 32), N]: + zero_points has shape uint32_t [K//block_size, ceil(N*bits, 32)] + scale has shape float16[K//block_size, N] +with "hqq", we have: + Input is stored as uint32_t with shape: [ceil(K*bits, 32), N]: + zero_points has shape float16 [K//block_size, N] + scale has shape [K//block_size, N] +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(DequantizeLinearBlockWise) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(DequantizeLinearBlockWise_ver1_doc) + .Attr("K", "original input feature count", AttributeProto::INT) + .Attr("N", "original output feature count", AttributeProto::INT) + .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) + .Attr("block_size", "group size used for weight quantization (default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("packing", "describes how the quantized weight is packed, one of [default, gptq, hqq]", AttributeProto::STRING, std::string("default")) + .Attr("accuracy_level", "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) " "(default unset). It is used to control how input A is quantized or downcast internally while " "doing computation, for example: 0 means input A will not be quantized or downcast while doing " "computation. 4 means input A can be quantized with the same block_size to int8 internally from " "type T1.", AttributeProto::INT, static_cast(0)) + .Input(0, "A", "The quantized input tensor", "T2") + .Input(1, "scales", "quantization scale", "T1") + .Input(2, "zero_points", "quantization zero points", "T3", OpSchema::Optional) + .Input(3, "g_idx", "group_idx for gptq", "T2", OpSchema::Optional) + .Output(0, "Y", "dequantized output tensor. The output tensor has the same rank as the input. 
", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)", "tensor(uint32)", "tensor(int32)"}, "Constrain quantized weight types to uint8/uint32/int32/float16.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(uint32)", "tensor(int32)", "tensor(float16)"}, "Constrain quantized zero point types to uint8/uint32/int32/float16.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + ONNX_NAMESPACE::TensorShapeProto resultShape; + + std::string packing = getAttribute(ctx, "packing", "default"); + if (packing == "packing") { + resultShape.add_dim()->set_dim_value(out_features); + resultShape.add_dim()->set_dim_value(in_features); + *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape; + } else { + resultShape.add_dim()->set_dim_value(in_features); + resultShape.add_dim()->set_dim_value(out_features); + *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape; + } + }); static const char* MatMulNBits_ver1_doc = R"DOC( MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. @@ -3340,7 +3440,7 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7 It support different packing style, ["default", "gptq", "hqq", "exl2", "awq"], Please note that "exl2", "awq" is not supported yet. 
with "default", we have - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - n_blocks_per_col = (K + block_size - 1) / block_size - blob_size = block_size / 8 * bits @@ -3357,13 +3457,14 @@ with "default", we have - [N * n_blocks_per_col] if bits > 4 with "gptq", we have: - Input B is stored as uint32_t with shape: [ceil(N, 32//bits), K]: - zero_points has shape uint32_t [N//block_size, K*bits//32] - scale has shape float16[N//block_size, K] + 32 = bits_of(int32) + Input B is stored as uint32_t with shape: [ceil(K*bits, 32), N]: + zero_points has shape uint32_t [K//block_size, ceil(N*bits, 32)] + scale has shape float16[K//block_size, N] with "hqq", we have: - Input B is stored as uint32_t with shape: [ceil(N, 32//bits), K]: - zero_points has shape float16 [N//block_size, K] - scale has shape [N//block_size, K] + Input B is stored as uint32_t with shape: [ceil(K*bits, 32), N]: + zero_points has shape float16 [K//block_size, N] + scale has shape [K//block_size, N] )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 35e8f05889ddf..3a876cc44ff20 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -15,6 +15,7 @@ import numpy as np import numpy.typing as npt import onnx +import torch from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto from packaging import version @@ -131,10 +132,13 @@ def __init__( self, block_size: int = 128, is_symmetric: bool = False, + accuracy_level: int | None = None, ): - super().__init__(algorithm="HQQ") + super().__init__(algorithm="DEFAULT") self.block_size = block_size self.is_symmetric = is_symmetric + self.bits = 4 + self.accuracy_level = accuracy_level def is_divisible(val1, val2): @@ -151,9 +155,9 @@ def __init__( # Proximal solver || weight - dequantize(quantize(weight))||_p^p @staticmethod def optimize_weights( - tensor: np.ndarray, - scale: np.ndarray, - zero: np.ndarray, + tensor: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor, min_max: list[int], axis: int = 0, opt_params: dict = None, # noqa: RUF013 @@ -167,29 +171,32 @@ def optimize_weights( opt_params["iters"], ) - w_f = tensor.astype(np.float32) - scale = scale.astype(np.float32) - zero = zero.astype(np.float32) + dtype = torch.float16 if tensor.is_cuda else torch.float32 + w_f = tensor.to(dtype) + scale = scale.to(dtype) + zero = zero.to(dtype) if lp_norm == 1: def shrink_op(x, beta): - return np.sign(x) * np.maximum(np.abs(x) - 1.0 / beta, 0) + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) else: def shrink_op(x, beta, p=lp_norm): - return np.sign(x) * np.maximum(np.abs(x) - (1.0 / beta) * np.power(np.abs(x) + 1e-8, p - 1), 0) + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1) + ) best_error = 1e4 for i in range(iters): - w_q = np.round(w_f * scale + zero).clip(min_max[0], min_max[1]) + w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1]) w_r = (w_q - zero) / scale w_e = shrink_op(w_f - w_r, beta) - zero = np.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdims=True) + zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True) beta *= kappa - current_error = float(np.abs(w_f - w_r).mean()) + 
current_error = float(torch.abs(w_f - w_r).mean()) if verbose: print(i, np.round(current_error, 6)) if current_error < best_error: @@ -206,30 +213,26 @@ def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits): if pack_tensor.shape[0] == ori_int_tensor.shape[0]: ori_int_tensor = ori_int_tensor.T pack_tensor = pack_tensor.T - compress_ratio = 32 // bits - i = 0 - row = 0 - while row < pack_tensor.shape[0]: - if bits in [2, 4, 8]: - for j in range(i, i + compress_ratio): - pack_tensor[row:] |= (ori_int_tensor[j::compress_ratio] << (bits * (j - i))).astype(np.uint32) - break - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") + if bits in [2, 4, 8]: + compress_ratio = 32 // bits + for j in range(0, compress_ratio): + pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j)) + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") # from Official implementation of Half-Quadratic Quantization (HQQ) def quantize_internal( self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1 ): if group_size is not None: - assert is_divisible(np.size(tensor), group_size), ( + assert is_divisible(tensor.numel(), group_size), ( "group_size should be divisble by the total tensor dimensions. shape: " + str(tensor.shape) + ", group_size: " + str(group_size) ) - weight = tensor.astype(np.float32) + weight = tensor.float() shape = weight.shape # Reshape for grouping @@ -241,8 +244,8 @@ def quantize_internal( _min, _max = weight.min(), weight.max() optimize = False else: - _min = weight.min(axis=axis, keepdims=True) # [0] - _max = weight.max(axis=axis, keepdims=True) # [0] + _min = weight.min(axis=axis, keepdim=True)[0] + _max = weight.max(axis=axis, keepdim=True)[0] max_v = 2**bits - 1 min_v = 0 @@ -250,11 +253,11 @@ def quantize_internal( # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on. # clamp to avoid half-precision problems - scale = (max_v / (_max - _min)).clip(max=2e4) + scale = (max_v / (_max - _min)).clamp(max=2e4) zero = -_min * scale if round_zero: - zero = np.round(zero) + zero = torch.round(zero) # Fine-tune weights if optimize: @@ -262,10 +265,10 @@ def quantize_internal( # Quantize # Necessary for fake quantization backprop - w_q = np.round(weight * scale + zero).clip(min_max[0], min_max[1]) - w_q = w_q.reshape(shape).astype(np.uint32) + w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1]) + w_q = w_q.reshape(shape).int() - scale = np.reciprocal(scale) + scale = 1.0 / scale if axis == 1: scale = scale.reshape(shape[0], -1) zero = zero.reshape(shape[0], -1) @@ -275,50 +278,58 @@ def quantize_internal( # cleanup del weight, _min, _max - return w_q.T, scale.T.astype(tensor.dtype), zero.T.astype(tensor.dtype) + return w_q.T, scale.T.to(tensor.dtype), zero.T.to(tensor.dtype) def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" - if node.op_type != "MatMul": return node # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") inputB = node.input[1] # noqa: N806 - B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 - if B is None: + b_pb, bs_graph = get_initializer(inputB, graph_stack) + if b_pb is None: logger.info("MatMul doesn't have const weight. 
Skip to quantize") return node # only care about constant weight - B_array = onnx.numpy_helper.to_array(B) # noqa: N806 - if len(B_array.shape) != 2: + b_array = onnx.numpy_helper.to_array(b_pb) + if len(b_array.shape) != 2: logger.info("MatMul weight is not 2D. Skip to quantize") return node # can only process 2-D matrix + b_array_torch = torch.from_numpy(b_array) + if torch.cuda.is_available(): + b_array_torch = b_array_torch.cuda() + quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal( + b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size + ) - quant_weight, scales, zero_points = self.quantize_internal( - B_array.T, bits=self.config.bits, group_size=self.config.block_size + packed_torch = torch.zeros( + (quant_weight_torch.shape[0] // 8, quant_weight_torch.shape[1]), + dtype=torch.int32, + device=quant_weight_torch.device, ) - packed = np.zeros((quant_weight.shape[0] // 8, quant_weight.shape[1]), dtype="uint32") - self.pack_on_row_fast_248bit(packed, quant_weight, self.config.bits) - B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 - B_quant.name = B.name + "_Q4" - for input in Bs_graph.input: + self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits) + scales = scales_torch.cpu().numpy() + zero_points = zero_points_torch.cpu().numpy() + b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) + b_quant.name = b_pb.name + "_Q4" + for input in bs_graph.input: if input.name == inputB: - Bs_graph.input.remove(input) + bs_graph.input.remove(input) break scales_tensor = onnx.numpy_helper.from_array(scales) - scales_tensor.name = B.name + "_scales" - Bs_graph.initializer.extend([B_quant, scales_tensor]) + scales_tensor.name = b_pb.name + "_scales" + bs_graph.initializer.extend([b_quant, scales_tensor]) - input_names = [node.input[0], B_quant.name, scales_tensor.name] + input_names = [node.input[0], b_quant.name, scales_tensor.name] zp_tensor = onnx.numpy_helper.from_array(zero_points) - zp_tensor.name = B.name + "_zero_points" - Bs_graph.initializer.extend([zp_tensor]) + zp_tensor.name = b_pb.name + "_zero_points" + bs_graph.initializer.extend([zp_tensor]) input_names.append(zp_tensor.name) kwargs = {} - rows, cols = B_array.shape + rows, cols = b_array.shape kwargs["K"] = rows kwargs["N"] = cols kwargs["bits"] = self.config.bits @@ -406,7 +417,7 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: Bs_graph.initializer.extend([B_quant, scales_tensor]) input_names = [node.input[0], B_quant.name, scales_tensor.name] - if not self.is_symmetric: + if not self.config.is_symmetric: zp_tensor = onnx.numpy_helper.from_array(zero_points) zp_tensor.name = B.name + "_zero_points" Bs_graph.initializer.extend([zp_tensor]) @@ -417,8 +428,8 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: kwargs["K"] = rows kwargs["N"] = cols kwargs["bits"] = 4 - kwargs["block_size"] = self.block_size - if self.accuracy_level is not None: + kwargs["block_size"] = self.config.block_size + if self.config.accuracy_level is not None: kwargs["accuracy_level"] = self.accuracy_level matmul_q4_node = onnx.helper.make_node( @@ -456,6 +467,11 @@ def __init__( self.accuracy_level = accuracy_level self.nodes_to_exclude = set(nodes_to_exclude) self.algo_config = algo_config + self.node_quantizer = None + if algo_config.algorithm == "HQQ": + self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config) + elif algo_config.algorithm == "DEFAULT": + self.node_quantizer = 
DefaultWeightOnlyQuantizer(self.algo_config) def _process_subgraph(self, graph_stack: list[GraphProto]): new_nodes = [] @@ -492,10 +508,9 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") out_node = node elif self.algo_config is not None and self.algo_config.algorithm == "HQQ": - hqq_quantizer = HQQWeightOnlyQuantizer(self.algo_config) - out_node = hqq_quantizer.quantize(node, graph_stack) + out_node = self.node_quantizer.quantize(node, graph_stack) else: - out_node = DefaultWeightOnlyQuantizer(self.algo_config).quantize(node, graph_stack) + out_node = self.node_quantizer.quantize(node, graph_stack) new_nodes.append(out_node) graph.ClearField("node") @@ -565,7 +580,7 @@ def inc_dataloader(): logger.info(f"complete quantization of model with {algorithm} algorithm.") def process(self): - if self.algo_config is None or self.algo_config.algorithm == "HQQ": + if self.algo_config.algorithm in ["HQQ", "DEFAULT"]: # use a stack to keep track of sub-graphs graph_stack = [self.model.graph()] opset_import = self.model.opset_import() @@ -576,7 +591,6 @@ def process(self): has_ms_domain = True if not has_ms_domain: opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - self._process_subgraph(graph_stack) self.model.clean_initializers() else: @@ -664,7 +678,9 @@ def parse_args(): if args.quant_method == "hqq": quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) elif args.quant_method == "default": - quant_config = DefaultWeightOnlyQuantConfig(block_size=args.block_size, is_symmetric=args.symmetric) + quant_config = DefaultWeightOnlyQuantConfig( + block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level + ) else: quant_config = None
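Usage note (not part of the diff): a minimal sketch of driving the updated Python quantizer API shown above. The model path and parameter values are placeholders, and constructor arguments not visible in this diff are assumed to keep their existing defaults.

```python
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import (
    DefaultWeightOnlyQuantConfig,
    HQQWeightOnlyQuantConfig,
    MatMul4BitsQuantizer,
)

model = onnx.load("model.onnx")  # placeholder path

# Either algorithm now routes through self.node_quantizer in process().
algo_config = HQQWeightOnlyQuantConfig(block_size=128, bits=4)
# algo_config = DefaultWeightOnlyQuantConfig(block_size=128, is_symmetric=False, accuracy_level=4)

quantizer = MatMul4BitsQuantizer(model, algo_config=algo_config)
quantizer.process()
# quantizer.model is the ONNXModel wrapper around the quantized ModelProto.
quantizer.model.save_model_to_file("model_q4.onnx")
```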