Skip to content

Commit

Permalink
Improve speed in combining per-channel data (#21563)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Improve speed in combining `per-channel` data for using a single
`np.concatenate` instead of multiple `np.concatenates` within a for
loop.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Fix the issue #21562

Signed-off-by: duansheng.liu <[email protected]>
  • Loading branch information
duanshengliu authored Aug 6, 2024
1 parent 4ad87ca commit b95aa05
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions onnxruntime/python/tools/quantization/base_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ def quantize_weight_per_channel_impl(
zero_point_list = []
scale_list = []
quantized_per_channel_data_list = []
weights_shape = list(weights.shape)
reshape_dims = list(weights_shape) # deep copy
reshape_dims[channel_axis] = 1 # only one per channel for reshape
for i in range(channel_count):
per_channel_data = weights.take(i, channel_axis)
channel_override_index = i if i < num_channel_overrides else 0
Expand Down Expand Up @@ -460,17 +463,10 @@ def quantize_weight_per_channel_impl(

zero_point_list.append(zero_point)
scale_list.append(scale)
quantized_per_channel_data_list.append(quantized_per_channel_data)
quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims))

# combine per_channel_data into one
weights_shape = list(weights.shape)
reshape_dims = list(weights_shape) # deep copy
reshape_dims[channel_axis] = 1 # only one per channel for reshape
quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims)
for i in range(1, len(quantized_per_channel_data_list)):
channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims)
quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis)

quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis)
q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
zp_name = weight_name + "_zero_point"
scale_name = weight_name + "_scale"
Expand Down

0 comments on commit b95aa05

Please sign in to comment.