diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 2f197cc7f31c0..aab04485246d6 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -418,6 +418,9 @@ def quantize_weight_per_channel_impl( zero_point_list = [] scale_list = [] quantized_per_channel_data_list = [] + weights_shape = list(weights.shape) + reshape_dims = list(weights_shape) # deep copy + reshape_dims[channel_axis] = 1 # only one per channel for reshape for i in range(channel_count): per_channel_data = weights.take(i, channel_axis) channel_override_index = i if i < num_channel_overrides else 0 @@ -460,17 +463,10 @@ def quantize_weight_per_channel_impl( zero_point_list.append(zero_point) scale_list.append(scale) - quantized_per_channel_data_list.append(quantized_per_channel_data) + quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims)) # combine per_channel_data into one - weights_shape = list(weights.shape) - reshape_dims = list(weights_shape) # deep copy - reshape_dims[channel_axis] = 1 # only one per channel for reshape - quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims) - for i in range(1, len(quantized_per_channel_data_list)): - channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims) - quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis) - + quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis) q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight_name + "_zero_point" scale_name = weight_name + "_scale"