
Failed to quantize ConvTranspose with per_channel=True #19694

Closed
mo-ja opened this issue Feb 28, 2024 · 0 comments · Fixed by #19996
Labels
quantization issues related to quantization

Comments

@mo-ja
Contributor

mo-ja commented Feb 28, 2024

Describe the issue

ConvTranspose layer cannot be quantized with per_channel=True.

[Screenshot from 2024-02-28 23-03-29 showing the quantization failure]

In the case of ConvTranspose, axis=1 of the weight corresponds to the number of output channels, whereas QDQConv's per_channel quantization always quantizes along axis=0.
This seems to cause an error during the bias scale calculation for ConvTranspose, because the shape of bias_scale does not match the shape of the bias.
I have solved this issue by using a different axis for Conv and ConvTranspose.
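The layout difference is easy to verify from the PyTorch modules themselves (a quick sketch using the same channel configuration as the repro below):

```python
import torch.nn as nn

# Conv2d weights are (out_channels, in_channels, kH, kW) -> output channels on axis 0.
# ConvTranspose2d weights are (in_channels, out_channels, kH, kW) -> output channels on axis 1.
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
deconv = nn.ConvTranspose2d(in_channels=3, out_channels=64, kernel_size=3)

print(tuple(conv.weight.shape))    # (64, 3, 3, 3)
print(tuple(deconv.weight.shape))  # (3, 64, 3, 3)
```

So per-channel scales computed along axis=0 of a ConvTranspose weight end up with one scale per input channel (3) rather than per output channel (64).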

To reproduce

Here is the onnx file I used: onnx.zip

model generation with torch

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv_transpose = nn.ConvTranspose2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv_transpose(x)
        return x


model = MyModel()
input_data = torch.randn(1, 3, 32, 32)

torch.onnx.export(model, input_data, "model.onnx", opset_version=13)
```

quantization code with ONNX Runtime

```python
import numpy as np

import onnxruntime.quantization as quantization


class SampleReader(quantization.CalibrationDataReader):
    def __init__(self):
        self.flg = False

    def get_next(self):
        if self.flg:
            return None
        self.flg = True
        return {"onnx::ConvTranspose_0": np.random.uniform(0, 1, [1, 3, 32, 32]).astype(np.float32)}


cdr = SampleReader()
quantization.quantize_static("./model.onnx", "quant.onnx", cdr, per_channel=True)
```

Raised Error

```
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_833326/1616630263.py in <module>
     12
     13 cdr = SampleReader()
---> 14 quantization.quantize_static("./model.onnx", "quant.onnx", cdr, per_channel=True)

~/.local/lib/python3.8/site-packages/onnxruntime/quantization/quantize.py in quantize_static(model_input, model_output, calibration_data_reader, quant_format, op_types_to_quantize, per_channel, reduce_range, activation_type, weight_type, nodes_to_quantize, nodes_to_exclude, use_external_data_format, calibrate_method, extra_options)
    535         )
    536
--> 537     quantizer.quantize_model()
    538     quantizer.model.save_model_to_file(model_output, use_external_data_format)
    539     if not pre_processed:

~/.local/lib/python3.8/site-packages/onnxruntime/quantization/qdq_quantizer.py in quantize_model(self)
    264         self._quantize_sharing_param_tensors()
    265         if self.quantize_bias:
--> 266             self._quantize_bias_tensors()
    267         self.remove_nodes()
    268         if not self.add_qdq_pair_to_weight:

~/.local/lib/python3.8/site-packages/onnxruntime/quantization/qdq_quantizer.py in _quantize_bias_tensors(self)
    480                 continue
    481             # Quantize the input
--> 482             self.quantize_bias_static(bias_name, input_name, weight_name, beta)
    483             init = find_by_name(bias_name, self.model.initializer())
    484             self.model.remove_initializer(init)

~/.local/lib/python3.8/site-packages/onnxruntime/quantization/onnx_quantizer.py in quantize_bias_static(self, bias_name, input_name, weight_name, beta)
    917             bias_scale = input_scale * weight_scale * beta
    918
--> 919             quantized_data = (np.asarray(bias_data) / bias_scale).round().astype(np.int32)
    920
    921             # update bias initializer

ValueError: operands could not be broadcast together with shapes (64,) (3,)
```
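The broadcast failure in the last frame can be reproduced in isolation with plain NumPy (a minimal sketch; the variable names mirror `quantize_bias_static`, and the scale values are made up):

```python
import numpy as np

# For ConvTranspose2d(in_channels=3, out_channels=64, ...), per-channel
# quantization along axis=0 yields one weight scale per *input* channel (3),
# while the bias has one element per *output* channel (64).
input_scale = np.float32(0.02)
weight_scale = np.full(3, 0.01, dtype=np.float32)  # scales taken along axis=0
bias_data = np.ones(64, dtype=np.float32)          # bias length = out_channels

bias_scale = input_scale * weight_scale
try:
    (np.asarray(bias_data) / bias_scale).round().astype(np.int32)
except ValueError as e:
    print(e)  # operands could not be broadcast together with shapes (64,) (3,)
```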

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04.5 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.1

ONNX Runtime API

Python

Architecture

X86

Execution Provider

Default CPU

Execution Provider Library Version

No response

@mo-ja mo-ja changed the title Failed to quantize ConVTranspose with per_channel=True Failed to quantize ConvTranspose with per_channel=True Feb 28, 2024
@wangyems wangyems added the quantization issues related to quantization label Feb 28, 2024
adrianlizarraga pushed a commit that referenced this issue Mar 30, 2024

### Description
- Update the axis value for per-channel quantization of QDQConv: use `axis=1` for the ConvTranspose operator.

### Motivation and Context
- This PR fixes #19694, which I opened.
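The gist of the change can be sketched as a per-operator axis choice (a hypothetical helper for illustration; the actual change lives in ONNX Runtime's quantization code, see #19996):

```python
def per_channel_weight_axis(op_type: str) -> int:
    """Illustrative helper: pick the per-channel quantization axis
    that corresponds to output channels for the given operator."""
    # Conv weights: (out_ch, in_ch, kH, kW)       -> output channels on axis 0
    # ConvTranspose weights: (in_ch, out_ch, ...) -> output channels on axis 1
    return 1 if op_type == "ConvTranspose" else 0

print(per_channel_weight_axis("Conv"))           # 0
print(per_channel_weight_axis("ConvTranspose"))  # 1
```

With the axis chosen this way, the number of weight scales matches the bias length, so the bias-scale broadcast succeeds.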
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this issue May 7, 2024