Skip to content

Commit

Permalink
Merge branch 'improve-combine-per-channel-data-speed' of github.com:d…
Browse files Browse the repository at this point in the history
…uanshengliu/onnxruntime into improve-combine-per-channel-data-speed
  • Loading branch information
duanshengliu committed Aug 5, 2024
2 parents 88c811b + 00fcfdb commit ef8c09e
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions onnxruntime/python/tools/quantization/base_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ def quantize_weight_per_channel_impl(
zero_point_list = []
scale_list = []
quantized_per_channel_data_list = []
weights_shape = list(weights.shape)
reshape_dims = list(weights_shape) # deep copy
reshape_dims[channel_axis] = 1 # only one per channel for reshape
for i in range(channel_count):
per_channel_data = weights.take(i, channel_axis)
channel_override_index = i if i < num_channel_overrides else 0
Expand Down Expand Up @@ -460,17 +463,10 @@ def quantize_weight_per_channel_impl(

zero_point_list.append(zero_point)
scale_list.append(scale)
quantized_per_channel_data_list.append(quantized_per_channel_data)
quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims))

# combine per_channel_data into one
weights_shape = list(weights.shape)
reshape_dims = list(weights_shape) # deep copy
reshape_dims[channel_axis] = 1 # only one per channel for reshape
quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims)
for i in range(1, len(quantized_per_channel_data_list)):
channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims)
quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis)

quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis)
q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
zp_name = weight_name + "_zero_point"
scale_name = weight_name + "_scale"
Expand Down

0 comments on commit ef8c09e

Please sign in to comment.