diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h
index 455623d602583..747f959e2e217 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h
+++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h
@@ -5,8 +5,11 @@
 
 #include "blockwise_quant_block.h"
 
+#include <vector>
+
 #include "core/framework/float16.h"
 #include "core/platform/threadpool.h"
+#include
 
 namespace onnxruntime {
 namespace contrib {
@@ -15,8 +18,8 @@ template <typename T, int32_t block_size, int32_t bits>
 void QuantizeBlockwise(
     uint8_t* dst,          // shape: [ N, block_per_K, block_blob_size ]
     const T* src,          // shape: [K, N]
-    T* scale,              // shape: [N, block_per_K]
-    uint8_t* zero_points,  // shape: [N, block_per_K]
+    T* scale,              // shape: [N * block_per_K]
+    uint8_t* zero_points,  // shape: [N * block_per_K] if bits > 4 else [(N * block_per_K + 1) / 2]
     int32_t N,
     int32_t K,
     onnxruntime::concurrency::ThreadPool* thread_pool) {
@@ -24,23 +27,40 @@ void QuantizeBlockwise(
       reinterpret_cast<BlockwiseQuantBlock<T, block_size, bits>*>(dst);
 
   int32_t block_per_K = (K + block_size - 1) / block_size;
-  int32_t task_count = N * block_per_K;
+  int32_t total_block_count = N * block_per_K;
+
+  std::vector<uint8_t> zero_points_tmp;  // to avoid race condition
+  (void)zero_points_tmp;
+  uint8_t* zero_points_tmp_ptr = zero_points;
+  if (bits <= 4 && zero_points != nullptr) {
+    zero_points_tmp.resize(total_block_count, 0);
+    zero_points_tmp_ptr = zero_points_tmp.data();
+  }
 
   concurrency::ThreadPool::TryBatchParallelFor(
       thread_pool,
-      task_count,
-      [&](ptrdiff_t task_idx) {
-        int32_t n = static_cast<int32_t>(task_idx / block_per_K);
-        int32_t k_block_idx = static_cast<int32_t>(task_idx % block_per_K);
+      total_block_count,
+      [&](ptrdiff_t block_idx) {
+        int32_t n = static_cast<int32_t>(block_idx / block_per_K);
+        int32_t k_block_idx = static_cast<int32_t>(block_idx % block_per_K);
         int32_t k = k_block_idx * block_size;
-        BlockwiseQuantBlock<T, block_size, bits>* blob_ptr = dst_blob + task_idx;
-        if (nullptr != zero_points) {
-          blob_ptr->quant(src + k * N + n, scale[task_idx], zero_points[task_idx], k, K, N);
+        BlockwiseQuantBlock<T, block_size, bits>* blob_ptr = dst_blob + block_idx;
+        if (nullptr != zero_points_tmp_ptr) {
+          blob_ptr->quant(src + k * N + n, scale[block_idx], zero_points_tmp_ptr[block_idx], k, K, N);
         } else {
-          blob_ptr->quant(src + k * N + n, scale[task_idx], k, K, N);
+          blob_ptr->quant(src + k * N + n, scale[block_idx], k, K, N);
         }
       },
       0);
+
+  if (bits <= 4 && zero_points != nullptr) {  // compact zero points
+    for (int32_t zp_idx = 0; zp_idx < total_block_count / 2; zp_idx++) {
+      zero_points[zp_idx] = ((zero_points_tmp[zp_idx * 2]) | (zero_points_tmp[zp_idx * 2 + 1] << 4));
+    }
+    if (total_block_count & 1) {
+      zero_points[total_block_count / 2] = (zero_points[total_block_count / 2] & 0xf0) | zero_points_tmp[total_block_count - 1];
+    }
+  }
 }
 
 #define QuantizeBlockwise4Bits(block_size) \
@@ -78,10 +98,10 @@ void QuantizeBlockwise(
 
 template <typename T, int32_t block_size, int32_t bits>
 void DequantizeBlockwise(
-    T* dst,                      // [N, K]
-    const uint8_t* src,          // [N, block_per_K, block_blob_size]
-    const T* scale,              // [N, block_per_K]
-    const uint8_t* zero_points,  // [N, block_per_K]
+    T* dst,                      // shape: [N, K]
+    const uint8_t* src,          // shape: [N, block_per_K, block_blob_size]
+    const T* scale,              // shape: [N, block_per_K]
+    const uint8_t* zero_points,  // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2]
     int32_t N,
     int32_t K,
     onnxruntime::concurrency::ThreadPool* thread_pool) {
@@ -100,7 +120,14 @@ void DequantizeBlockwise(
         int32_t k = k_block_idx * block_size;
         const BlockwiseQuantBlock<T, block_size, bits>* blob_ptr = src_blob + task_idx;
         if (nullptr != zero_points) {
-          blob_ptr->dequant(dst + n * K + k, scale[task_idx], zero_points[task_idx], k, K);
+          // if bits >= 4
+          if constexpr (bits > 4) {  // zero point is stored with a byte
+            blob_ptr->dequant(dst + n * K + k, scale[task_idx], zero_points[task_idx], k, K);
+          } else {  // zero points is stored with 4bits
+            uint8_t zp = zero_points[task_idx / 2];
+            zp = (task_idx & 1) ? (zp >> 4) : (zp & 0xf);
+            blob_ptr->dequant(dst + n * K + k, scale[task_idx], zp, k, K);
+          }
         } else {
           blob_ptr->dequant(dst + n * K + k, scale[task_idx], k, K);
         }
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
index 35e4adcfd9f87..505ed86d67f17 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
@@ -50,16 +50,19 @@ __global__ void Dequantize4BitsKernel(
     const T* scale_data,
     const uint8_t* zero_points,
     int block_size,
-    int blocks_per_tb,
+    int blocks_per_threadblock,
     int shift) {
-  int block_id = blockIdx.x * blocks_per_tb + ((threadIdx.x * 8) >> shift);
-  int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
+  int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
+  int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
   uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
   T scale = *(scale_data + block_id);
-  T zero_point = static_cast<T>(zero_points ? zero_points[block_id] : (uint8_t)(8));
+  uint8_t zp = 8;
+  if (zero_points) {
+    zp = (block_id & 0x01) ? (zero_points[block_id / 2] >> 4) : (zero_points[block_id / 2] & 0x0f);
+  }
 
   output = output + element_offset;
-  DequantizeEightElements(quant_value, scale, zero_point, output);
+  DequantizeEightElements(quant_value, scale, static_cast<T>(zp), output);
 }
 
 template <class T>
@@ -67,16 +70,17 @@ Status Dequantize4Bits(
     T* output,
     const uint8_t* quant_data,
     const T* scales_data,
-    const uint8_t* zero_points,
+    const uint8_t* zero_points,  // shape: [N, (block_per_K + 1)/2]
     int k,
     int n,
     int block_size,
     cudaStream_t stream) {
+  // k is padded and equal to block_per_K * block_size
   ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size");
   constexpr int element_per_thread = 8;
-  int blocks_per_tb = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
-  int k_blocks = k / block_size;
-  int blocks_per_grid = static_cast<int>(CeilDiv(n * k_blocks, blocks_per_tb));
+  int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
+  int blocks_per_K = k / block_size;
+  int blocks_per_grid = static_cast<int>(CeilDiv(n * blocks_per_K, blocks_per_threadblock));
   int shift = static_cast<int>(log2f(float(block_size)));
 
   Dequantize4BitsKernel<<<blocks_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
@@ -85,7 +89,7 @@ Status Dequantize4Bits(
       scales_data,
       zero_points,
       block_size,
-      blocks_per_tb,
+      blocks_per_threadblock,
       shift);
 
   return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu
index adaec89154578..499abbceeabc3 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu
@@ -62,7 +62,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant,
 
 constexpr int BLOCKSIZEN = 8;
 
-template <class T, int group_size>
+template <class T, int block_size>
 __global__ void MatMulFloatInt4Kernel(
     T* output,
     const T* a_data,
@@ -77,7 +77,7 @@ __global__ void MatMulFloatInt4Kernel(
   int lane_id = threadIdx.x;
   int warp_id = threadIdx.y;
   int n_id = n_block_id * BLOCKSIZEN + warp_id;
-  int group_count = (k + group_size - 1) / group_size;
+  int blocks_per_K = (k + block_size - 1) / block_size;
   int thread_id = warp_id * 32 + lane_id;
   int k_iter = k / 256;
 
@@ -85,31 +85,35 @@ __global__ void MatMulFloatInt4Kernel(
 
   // load scale to shared buffer
   T* b_scale_vec = (T*)shared_buffer;
-  uint8_t* b_zp_vec = reinterpret_cast<uint8_t*>(b_scale_vec + BLOCKSIZEN * group_count);
-  int offset = n_block_id * BLOCKSIZEN * group_count;
-  for (int i = thread_id; i < BLOCKSIZEN * group_count; i += 256) {
+  uint8_t* b_zp_vec = reinterpret_cast<uint8_t*>(b_scale_vec + BLOCKSIZEN * blocks_per_K);
+  int offset = n_block_id * BLOCKSIZEN * blocks_per_K;
+  for (int i = thread_id; i < BLOCKSIZEN * blocks_per_K; i += 256) {
     b_scale_vec[i] = scales_data[offset + i];
-    b_zp_vec[i] = zero_points != nullptr ? zero_points[offset + i] : uint8_t(8);
+  }
+  for (int i = thread_id; i < BLOCKSIZEN * blocks_per_K / 2; i += 256) {
+    b_zp_vec[i] = zero_points != nullptr ? zero_points[offset / 2 + i] : uint8_t(0x88);
   }
   __syncthreads();
 
   a_data += m_id * k;
-  b_data_quant += n_id * group_count * (group_size / 2);
+  b_data_quant += n_id * blocks_per_K * (block_size / 2);
 
   float sum = 0.f;
   int k_id = 0;
   for (; k_id < (k & 0xffffff00); k_id += 256) {
     uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + (k_id >> 1) + lane_id * 4));
-    T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
-    uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+    int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size;
+    T scale = b_scale_vec[block_idx];
+    uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f);
     sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
   }
 
   // handle reminder
   if (k_id + lane_id * 8 < k) {
     uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + k_iter * 128 + lane_id * 4));
-    T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
-    uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+    int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size;
+    T scale = b_scale_vec[block_idx];
+    uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f);
     sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
   }
 
@@ -133,29 +137,29 @@ bool TryMatMul4Bits(
     int m,
     int n,
     int k,
-    int group_size,
+    int block_size,
     cudaStream_t stream) {
   if (n % BLOCKSIZEN != 0 || k % 8 != 0 || m > 1) {
     return false;
   }
   dim3 blocks((n + BLOCKSIZEN - 1) / BLOCKSIZEN, m);
   dim3 threads(32, 8);
-  int shared_mem_size = (sizeof(T) + 1) * ((k + group_size - 1) / group_size * 8);
+  int shared_mem_size = (sizeof(T) + 1) * ((k + block_size - 1) / block_size * 8);
 
-  if (16 == group_size) {
+  if (16 == block_size) {
     MatMulFloatInt4Kernel<T, 16><<<blocks, threads, shared_mem_size, stream>>>(
         output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
-  } else if (32 == group_size) {
+  } else if (32 == block_size) {
     MatMulFloatInt4Kernel<T, 32><<<blocks, threads, shared_mem_size, stream>>>(
         output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
-  } else if (64 == group_size) {
+  } else if (64 == block_size) {
     MatMulFloatInt4Kernel<T, 64><<<blocks, threads, shared_mem_size, stream>>>(
         output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
-  } else if (128 == group_size) {
+  } else if (128 == block_size) {
     MatMulFloatInt4Kernel<T, 128><<<blocks, threads, shared_mem_size, stream>>>(
         output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
   } else {
-    ORT_THROW("block size ", group_size, " is not supported");
+    ORT_THROW("block size ", block_size, " is not supported");
   }
 
   return true;
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 6f75ec3ff09e0..1e3277b4a8d5e 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3026,10 +3026,10 @@ struct Blob {
   - shape: [n_cols, n_blocks_per_col, blob_size]
   - type: uint8_t
 scales:
-  - shape: [n_cols, n_blocks_per_col]
+  - shape: [n_cols * n_blocks_per_col]
   - type: float32 or float16. Same as input A
 zero_points
-  - shape: [n_cols, (n_blocks_per_col * 4 + 4) / 8] for nbits <= 4 and [n_cols, n_blocks_per_col] for nbits > 4
+  - shape: [(n_cols * n_blocks_per_col + 1) / 2] for nbits <= 4 and [n_cols * n_blocks_per_col] for nbits > 4
   - type: uint8_t
 )DOC";
 
@@ -3037,11 +3037,7 @@ zero_points
 ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits)
     .SetDomain(kMSDomain)
     .SinceVersion(1)
-<<<<<<< HEAD
-    .SetDoc(MatMulWithCompressWeight_ver1_doc)
-=======
     .SetDoc(MatMulNBits_ver1_doc)
->>>>>>> change matmul 4bits name
     .Attr("K", "size of each input feature", AttributeProto::INT)
     .Attr("N", "size of each output feature", AttributeProto::INT)
     .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc
index d8881022a4cb6..b54905ca57bc1 100644
--- a/onnxruntime/python/onnxruntime_pybind_quant.cc
+++ b/onnxruntime/python/onnxruntime_pybind_quant.cc
@@ -38,7 +38,7 @@ void QuantizeMatMulNBitsBlockwise(
     py::array_t<uint8_t> dst,          // shape: [ N, block_per_K, block_blob_size ]
     py::array_t<T> src,                // shape: [K, N]
     py::array_t<T> scale,              // shape: [N, block_per_K]
-    py::array_t<uint8_t> zero_points,  // shape: [N, block_per_K]
+    py::array_t<uint8_t> zero_points,  // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2]
     int32_t block_size,
     int32_t N,
     int32_t K,
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_fp_int4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_fp_int4.cu
index 0a1dd2f4978e8..1e3780137f1cb 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_fp_int4.cu
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_fp_int4.cu
@@ -50,7 +50,11 @@ struct MatrixFloatInt4Params :
 template <typename T>
 class MatrixFloatInt4 : public IKernelExplorer {
  public:
-  MatrixFloatInt4(DeviceArray& output, DeviceArray& a, DeviceArray& b, DeviceArray& scales, int m, int n, int k) {
+  MatrixFloatInt4(DeviceArray& output,
+                  DeviceArray& a,
+                  DeviceArray& b,
+                  DeviceArray& scales,
+                  int m, int n, int k) {
     params_.tuning_ctx = TuningContext();
     params_.stream = Stream();
     params_.output_ = static_cast<T*>(output.ptr());
@@ -63,6 +67,15 @@ class MatrixFloatInt4 : public IKernelExplorer {
     params_.k_ = k;
   }
 
+  MatrixFloatInt4(DeviceArray& output,
+                  DeviceArray& a,
+                  DeviceArray& b,
+                  DeviceArray& scales,
+                  DeviceArray& zeropoints,
+                  int m, int n, int k) : MatrixFloatInt4(output, a, b, scales, m, n, k) {
+    params_.zero_points_ = static_cast<uint8_t*>(zeropoints.ptr());
+  }
+
   void Run() override {
     contrib::cuda::TryMatMul4Bits(
         params_.output_,
@@ -83,11 +96,12 @@ class MatrixFloatInt4 : public IKernelExplorer {
   ParamsT params_{};
 };
 
-#define REGISTER_OP(name, type)                                                                 \
-  py::class_<name<type>>(m, #name "_" #type)                                                    \
-      .def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>())   \
-      .def("SetRepeats", &name<type>::SetRepeats)                                               \
-      .def("Profile", &name<type>::Profile)                                                     \
+#define REGISTER_OP(name, type)                                                                               \
+  py::class_<name<type>>(m, #name "_" #type)                                                                  \
+      .def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>())                 \
+      .def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>())   \
+      .def("SetRepeats", &name<type>::SetRepeats)                                                             \
+      .def("Profile", &name<type>::Profile)                                                                   \
       .def("Run", &name<type>::Run);
 
 KE_REGISTER(m) {
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_fp_int4.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_fp_int4.py
index ada0ffd31e957..9cb937a13ff27 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_fp_int4.py
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_fp_int4.py
@@ -18,6 +18,7 @@ def dtype_to_funcs(dtype):
     }
     return type_map[dtype]
 
+
 def dtype_to_funcs_cublas(dtype):
     type_map = {
         "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))),
@@ -26,37 +27,54 @@ def dtype_to_funcs_cublas(dtype):
     return type_map[dtype]
 
 
-
 dtypes = ["float16", "float32"]
 
 
 @dataclass
-class MatrixFpInt4Metric(ke.BandwidthMetric):
+class MatrixMulMetric(ke.BandwidthMetric):
     m: int
     n: int
    k: int
 
     def report(self):
-        return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}"
+        return (
+            f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}"
+        )
+
+
+@dataclass
+class MatrixFpInt4Metric(MatrixMulMetric):
+    is_symmetric: bool
+
+    def report(self):
+        return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} is_symmetric={self.is_symmetric} {self.name}"
 
 
-def profile_matmul_fp_int4_func(m, n, k, dtype, func):
+def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric):
     np.random.seed(0)
     output = np.random.rand(m, n).astype(dtype)
     a = np.random.rand(m, k).astype(dtype)
-    b = np.random.randint(low=0, high=127, size=(n, (k+31)//32, 16)).astype('uint8')
-    scales = np.random.rand(n, (k+31)//32).astype(dtype)
+    b = np.random.randint(low=0, high=127, size=(n, (k + 31) // 32, 16)).astype("uint8")
+    scales = np.random.rand(n * ((k + 31) // 32)).astype(dtype)
+    zeropoints = np.random.rand((n * ((k + 31) // 32) + 1) // 2).astype(dtype)
     output_d = ke.DeviceArray(output)
     a_d = ke.DeviceArray(a)
     b_d = ke.DeviceArray(b)
     scales_d = ke.DeviceArray(scales)
+    zeropoints_d = ke.DeviceArray(zeropoints)
     f = getattr(ke, func)
-    my_op = f(output_d, a_d, b_d, scales_d, m, n, k)
+
+    my_op = (
+        f(output_d, a_d, b_d, scales_d, m, n, k)
+        if is_symmetric
+        else f(output_d, a_d, b_d, scales_d, zeropoints_d, m, n, k)
+    )
     duration_ms = my_op.Profile()
     total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))
 
-    ke.report(MatrixFpInt4Metric(func, dtype, duration_ms, total_bytes, m, n, k))
+    ke.report(MatrixFpInt4Metric(func, dtype, duration_ms, total_bytes, m, n, k, is_symmetric))
+
 
 def profile_gemm_func(m, n, k, dtype, func):
     np.random.seed(0)
@@ -72,13 +90,16 @@ def profile_gemm_func(m, n, k, dtype, func):
     duration_ms = my_op.Profile()
     total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))
 
-    ke.report(MatrixFpInt4Metric(func, dtype, duration_ms, total_bytes, m, n, k))
+    ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k))
 
 
 def profile_with_args(m, n, k, dtype, sort):
     with ke.benchmark(sort):
         for func in dtype_to_funcs(dtype):
-            profile_matmul_fp_int4_func(m, n, k, dtype, func)
+            profile_matmul_fp_int4_func(m, n, k, dtype, func, True)
+
+        for func in dtype_to_funcs(dtype):
+            profile_matmul_fp_int4_func(m, n, k, dtype, func, False)
 
         for func in dtype_to_funcs_cublas(dtype):
             profile_gemm_func(m, n, k, dtype, func)
@@ -88,8 +109,8 @@ def profile():
     dims_m = [1]
     for dt in dtypes:
         for m in dims_m:
-            for n,k in ((4096, 4096), (4096, 12288), (12288, 4096)):
-                profile_with_args(m, n, k, dt, True)
+            for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
+                profile_with_args(m, n, k, dt, False)
             print()
 
diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index bc3ac66439ab7..8dd5964559390 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -24,6 +24,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 class MatMul4BitsQuantizer:
     """Perform 4b quantization of constant MatMul weights"""
 
@@ -59,8 +60,8 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
 
         # block wise quantization, each block comes from a single column
         packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
-        scales = np.zeros((cols, k_blocks), dtype=fp32weight.dtype)
-        zero_point = np.zeros((cols, k_blocks), dtype="uint8")
+        scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
+        zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
         quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric)
         return (packed, scales, zero_point)
 
@@ -194,12 +195,21 @@ def parse_args():
     parser.add_argument("--input_model", required=True, help="Path to the input model file")
     parser.add_argument("--output_model", required=True, help="Path to the output model file")
     parser.add_argument("--block_size", required=False, default=32)
-    parser.add_argument("--symmetric", required=False, default=True, help="Indicate whether to quantize the model symmetrically")
+    parser.add_argument(
+        "--symmetric", required=False, default=True, help="Indicate whether to quantize the model symmetrically"
+    )
     parser.add_argument("-v", "--verbose", required=False, action="store_true")
     parser.set_defaults(verbose=False)
     parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
     parser.set_defaults(use_external_data_format=False)
-    parser.add_argument("--nodes_to_exclude", nargs='+', type=str, required=False, default=[], help="Specify the nodes to be excluded from quantization with node names")
+    parser.add_argument(
+        "--nodes_to_exclude",
+        nargs="+",
+        type=str,
+        required=False,
+        default=[],
+        help="Specify the nodes to be excluded from quantization with node names",
+    )
 
     return parser.parse_args()
 
@@ -223,6 +233,6 @@ def parse_args():
         raise Exception(f"file {output_model_path} already exists")
 
     model = onnx.load(input_model_path)
-    quant = MatMul4BitsQuantizer(model, args.block_size, args.symmetric, nodes_to_exclude = args.nodes_to_exclude)
+    quant = MatMul4BitsQuantizer(model, args.block_size, args.symmetric, nodes_to_exclude=args.nodes_to_exclude)
     quant.process()
     quant.model.save_model_to_file(output_model_path, True)
diff --git a/onnxruntime/test/contrib_ops/matmul_with_quant_weight_test.cc b/onnxruntime/test/contrib_ops/matmul_with_quant_weight_test.cc
index 026f649a92428..a44a82bd19252 100644
--- a/onnxruntime/test/contrib_ops/matmul_with_quant_weight_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_with_quant_weight_test.cc
@@ -75,7 +75,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop
   int64_t buf_size = number_of_block * (block_size * 4 / 8);
   std::vector<uint8_t> input1_vals(buf_size);
   std::vector<float> scales(number_of_block);
-  std::vector<uint8_t> zp(number_of_block);
+  std::vector<uint8_t> zp((N * block_per_k + 1) / 2);
   QuantizeDequantize(input1_f_vals, input1_vals, scales, has_zeropoint ? &zp : nullptr, N, K, block_size);
 
@@ -98,9 +98,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop
   if (use_float16) {
     test.AddInput<MLFloat16>("A", {M, K}, ToFloat16(input0_vals), false);
     test.AddInput<uint8_t>("B", {N, block_per_k, block_blob_size}, input1_vals, true);
-    test.AddInput<MLFloat16>("scales", {N, block_per_k}, ToFloat16(scales), true);
+    test.AddInput<MLFloat16>("scales", {N * block_per_k}, ToFloat16(scales), true);
     if (has_zeropoint) {
-      test.AddInput<uint8_t>("zero_points", {N, block_per_k}, zp, true);
+      test.AddInput<uint8_t>("zero_points", {(N * block_per_k + 1) / 2}, zp, true);
     }
 
     test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(expected_vals));
@@ -112,9 +112,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop
   } else {
     test.AddInput<float>("A", {M, K}, input0_vals, false);
     test.AddInput<uint8_t>("B", {N, block_per_k, block_blob_size}, input1_vals, true);
-    test.AddInput<float>("scales", {N, block_per_k}, scales, true);
+    test.AddInput<float>("scales", {N * block_per_k}, scales, true);
     if (has_zeropoint) {
-      test.AddInput<uint8_t>("zero_points", {N, block_per_k}, zp, true);
+      test.AddInput<uint8_t>("zero_points", {(N * block_per_k + 1) / 2}, zp, true);
     }
 
     test.AddOutput<float>("Y", {M, N}, expected_vals);
@@ -149,6 +149,43 @@ TEST(MatMulNBits, Float16) {
     }
   }
 }
+TEST(MatMulNBits, Float16_Dequantize) {
+  for (auto M : {1, 2, 100}) {
+    for (auto N : {1, 2}) {
+      for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
+        for (auto block_size : {16}) {
+          RunTest(M, N, K, block_size, false, true);
+          RunTest(M, N, K, block_size, true, true);
+        }
+      }
+    }
+  }
+}
+
+TEST(MatMulNBits, Float16_MatMul) {
+  for (auto M : {1}) {
+    for (auto N : {32}) {
+      for (auto K : {64}) {
+        for (auto block_size : {16}) {
+          RunTest(M, N, K, block_size, false, true);
+        }
+      }
+    }
+  }
+}
+
+TEST(MatMulNBits, Float16_MatMul_zp) {
+  for (auto M : {1}) {
+    for (auto N : {32}) {
+      for (auto K : {64}) {
+        for (auto block_size : {16}) {
+          RunTest(M, N, K, block_size, true, true);
+        }
+      }
+    }
+  }
+}
+
 #endif
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
index 6f36a5ed7866d..409cf7bf7c483 100644
--- a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
+++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
@@ -11,18 +11,20 @@
 from onnxruntime.capi._pybind_state import quantize_matmul_4bits
 
+
 def dequantize_blockwise_4bits(quant_values, scale, zero_point, valid_len):
     blob_size = quant_values.shape[0]
     block_size = blob_size * 2
-
+
     quant_float = np.zeros((block_size), dtype=scale.dtype)
     for b in range(blob_size):
         v = quant_values[b]
-        quant_float[2*b] = ((v&0xf) - zero_point) * scale if 2 * b < valid_len else 0.0
-        quant_float[2*b + 1] = ((v>>4) - zero_point) * scale if 2*b+1 < valid_len else 0.0
+        quant_float[2 * b] = ((v & 0xF) - zero_point) * scale if 2 * b < valid_len else 0.0
+        quant_float[2 * b + 1] = ((v >> 4) - zero_point) * scale if 2 * b + 1 < valid_len else 0.0
 
     return quant_float
 
-def quantize_blockwise_4bits_ref(matrix_float:npt.ArrayLike, block_size:int, is_symmetric:bool):
+
+def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool):
     if len(matrix_float.shape) != 2:
         raise ValueError("Current int4 block quantization only supports 2D tensors!")
     rows, cols = matrix_float.shape
@@ -36,20 +38,20 @@ def quantize_blockwise_4bits_ref(matrix_float:npt.ArrayLike, block_size:int, is_
     matrix_float_padded = np.pad(matrix_float, ((0, pad_len), (0, 0)), "constant")
(0, 0)), "constant") packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") - scales = np.zeros((cols, k_blocks), dtype=matrix_float_padded.dtype) - zero_point = np.zeros((cols, k_blocks), dtype="uint8") + scales = np.zeros((cols * k_blocks), dtype=matrix_float_padded.dtype) + zero_point = np.full((cols * k_blocks + 1) // 2, 136, dtype="uint8") matrix_float_padded = np.transpose(matrix_float_padded) for n in range(cols): for k_id in range(0, rows, block_size): if is_symmetric: - amax_idx = np.argmax(np.abs(matrix_float_padded[n, k_id:k_id+block_size])) + amax_idx = np.argmax(np.abs(matrix_float_padded[n, k_id : k_id + block_size])) bmax = np.float32(matrix_float_padded[n, k_id + amax_idx]) scale = bmax / (-8.0) zp = 8 else: - vmin = np.min(np.float32(matrix_float_padded[n, k_id:k_id+block_size])) - vmax = np.max(np.float32(matrix_float_padded[n, k_id:k_id+block_size])) + vmin = np.min(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) + vmax = np.max(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) vmin = min(vmin, 0.0) vmax = max(vmax, 0.0) scale = (vmax - vmin) / ((1 << 4) - 1) @@ -59,44 +61,75 @@ def quantize_blockwise_4bits_ref(matrix_float:npt.ArrayLike, block_size:int, is_ zp = min(15, max(0, round(zero_point_fp))) reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 - scales[n, k_id // block_size] = scale - zero_point[n, k_id // block_size] = zp - - blk_int0 = np.clip(np.round(np.float32(matrix_float_padded[n, k_id:k_id+block_size:2] * reciprocal_scale + zp)), 0, 15).astype("uint8") - blk_int1 = np.clip(np.round(np.float32(matrix_float_padded[n, k_id + 1:k_id+block_size:2] * reciprocal_scale + zp)), 0, 15).astype("uint8") + block_idx = n * k_blocks + k_id // block_size + scales[block_idx] = scale + zp_pair = zero_point[block_idx // 2] + zero_point[block_idx // 2] = ((zp_pair & 0x0F) | (zp << 4)) if (block_idx & 1) else ((zp_pair & 0xF0) | zp) + + blk_int0 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 2] * reciprocal_scale + zp)), + 0, + 15, + ).astype("uint8") + blk_int1 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 2] * reciprocal_scale + zp)), + 0, + 15, + ).astype("uint8") packed[n, k_id // block_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) return (packed, scales, zero_point) -def quantize_blockwise_4bits_target(matrix_float:npt.ArrayLike, block_size:int, is_symmetric:bool): + +def quantize_blockwise_4bits_target(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool): if len(matrix_float.shape) != 2: raise ValueError("Current int4 block quantization only supports 2D tensors!") rows, cols = matrix_float.shape k_blocks = (rows + block_size - 1) // block_size packed = np.zeros((cols, k_blocks, block_size // 2), dtype="uint8") - scales = np.zeros((cols, k_blocks), dtype=matrix_float.dtype) - zero_point = np.full((cols, k_blocks), 8, dtype="uint8") + scales = np.zeros((cols * k_blocks), dtype=matrix_float.dtype) + zero_point = np.full((cols * k_blocks + 1) // 2, 136, dtype="uint8") quantize_matmul_4bits(packed, matrix_float, scales, zero_point, block_size, cols, rows, is_symmetric) return (packed, scales, zero_point) class TestQuantizeBlockwise4Bits(unittest.TestCase): def test_quantize_blockwise_4bits(self): - for rows,cols in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: + for rows, cols in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: for block_size in [16, 32, 64, 128]: for type in [np.float32, 
                 for type in [np.float32, np.float16]:
                     for is_symmetric in [True, False]:
                         matrix_float = np.random.rand(rows, cols).astype(type)
-                        quant_value_ref, scales_ref, zero_point_ref = quantize_blockwise_4bits_ref(matrix_float, block_size, is_symmetric)
-                        quant_value, scales, zero_point = quantize_blockwise_4bits_target(matrix_float, block_size, is_symmetric)
+                        quant_value_ref, scales_ref, zero_point_ref = quantize_blockwise_4bits_ref(
+                            matrix_float, block_size, is_symmetric
+                        )
+                        quant_value, scales, zero_point = quantize_blockwise_4bits_target(
+                            matrix_float, block_size, is_symmetric
+                        )
                         assert np.allclose(scales_ref, scales)
                         assert np.allclose(zero_point_ref, zero_point)
                         for c in range(quant_value_ref.shape[0]):
                             for k in range(quant_value_ref.shape[1]):
-                                assert np.allclose(dequantize_blockwise_4bits(quant_value_ref[c][k], scales_ref[c][k], zero_point_ref[c][k], min(block_size, rows - k * block_size)),
-                                                   dequantize_blockwise_4bits(quant_value[c][k], scales[c][k], zero_point[c][k], min(block_size, rows - k * block_size)),
-                                                   atol= 1.2 * abs(scales[c][k]))
+                                block_idx = c * quant_value_ref.shape[1] + k
+                                zp_idx = block_idx // 2
+                                assert np.allclose(
+                                    dequantize_blockwise_4bits(
+                                        quant_value_ref[c][k],
+                                        scales_ref[block_idx],
+                                        (zero_point_ref[zp_idx] >> 4)
+                                        if (block_idx & 1)
+                                        else (zero_point_ref[zp_idx] & 0x0F),
+                                        min(block_size, rows - k * block_size),
+                                    ),
+                                    dequantize_blockwise_4bits(
+                                        quant_value[c][k],
+                                        scales[block_idx],
+                                        (zero_point[zp_idx] >> 4) if (block_idx & 1) else (zero_point[zp_idx] & 0x0F),
+                                        min(block_size, rows - k * block_size),
+                                    ),
+                                    atol=1.2 * abs(scales[block_idx]),
+                                )
 
 
 if __name__ == "__main__":
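
For reference, the zero-point layout that this change applies end to end (the CPU compaction loop in QuantizeBlockwise, the unpacking in Dequantize4BitsKernel and MatMulFloatInt4Kernel, and the Python reference in test_quantizeblockwise_4bits.py) packs two 4-bit zero points per byte: block 2*i occupies the low nibble and block 2*i+1 the high nibble of byte i, with 0x88 (zero point 8 in both nibbles) as the default used for symmetric quantization. Below is a minimal NumPy sketch of that convention; pack_zero_points and unpack_zero_point are hypothetical helper names used only for illustration and do not appear in the patch.

import numpy as np

def pack_zero_points(zp_per_block: np.ndarray) -> np.ndarray:
    # One 4-bit zero point per block: even block -> low nibble, odd block -> high nibble.
    n_blocks = zp_per_block.shape[0]
    packed = np.full((n_blocks + 1) // 2, 0x88, dtype=np.uint8)  # 0x88 == zero point 8 in both nibbles
    pairs = n_blocks // 2
    packed[:pairs] = (zp_per_block[0::2][:pairs] & 0x0F) | ((zp_per_block[1::2] & 0x0F) << 4)
    if n_blocks & 1:  # odd block count: last byte keeps the default 8 in its unused high nibble
        packed[-1] = (packed[-1] & 0xF0) | (zp_per_block[-1] & 0x0F)
    return packed

def unpack_zero_point(packed: np.ndarray, block_idx: int) -> int:
    # Mirrors the kernel-side selection: high nibble for odd block_idx, low nibble for even.
    byte = int(packed[block_idx // 2])
    return (byte >> 4) if (block_idx & 1) else (byte & 0x0F)

# Example: three blocks with zero points 3, 12, 7 pack into [0xC3, 0x87].
zps = np.array([3, 12, 7], dtype=np.uint8)
packed = pack_zero_points(zps)
assert [unpack_zero_point(packed, i) for i in range(3)] == [3, 12, 7]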