From b95aa0563ff99cfcda3f036fe11d8aeecb3539a1 Mon Sep 17 00:00:00 2001 From: duanshengliu <44742794+duanshengliu@users.noreply.github.com> Date: Wed, 7 Aug 2024 07:23:20 +0800 Subject: [PATCH] Improve speed in combining per-channel data (#21563) ### Description Improve speed in combining `per-channel` data for using a single `np.concatenate` instead of multiple `np.concatenates` within a for loop. ### Motivation and Context Fix the issue https://github.com/microsoft/onnxruntime/issues/21562 Signed-off-by: duansheng.liu <44742794+duanshengliu@users.noreply.github.com> --- .../python/tools/quantization/base_quantizer.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 2f197cc7f31c0..aab04485246d6 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -418,6 +418,9 @@ def quantize_weight_per_channel_impl( zero_point_list = [] scale_list = [] quantized_per_channel_data_list = [] + weights_shape = list(weights.shape) + reshape_dims = list(weights_shape) # deep copy + reshape_dims[channel_axis] = 1 # only one per channel for reshape for i in range(channel_count): per_channel_data = weights.take(i, channel_axis) channel_override_index = i if i < num_channel_overrides else 0 @@ -460,17 +463,10 @@ def quantize_weight_per_channel_impl( zero_point_list.append(zero_point) scale_list.append(scale) - quantized_per_channel_data_list.append(quantized_per_channel_data) + quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims)) # combine per_channel_data into one - weights_shape = list(weights.shape) - reshape_dims = list(weights_shape) # deep copy - reshape_dims[channel_axis] = 1 # only one per channel for reshape - quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims) - for i in range(1, len(quantized_per_channel_data_list)): - channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims) - quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis) - + quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis) q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight_name + "_zero_point" scale_name = weight_name + "_scale"