[Performance] Running image preprocessing model in onnx takes significant more time #19329

Closed
arseniymerkulov opened this issue Jan 30, 2024 · 6 comments
Labels: converter (related to ONNX converters), core runtime (issues related to core runtime), performance (issues related to performance regressions)

Comments

arseniymerkulov commented Jan 30, 2024

Describe the issue

I have a torch vision transformer model and a torch preprocessing model. I converted each of them separately to ONNX format. The issue concerns the preprocessing model; its conversion code is given below.

After conversion, I compared inference for the two model pairs (torch processor + torch transformer, onnx processor + onnx transformer) on 72 input images. The processor model handles images one at a time, and both pairs ran on CPU. For onnx I measured the total time of the InferenceSession runs; for torch, the total time of the forward calls.
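A minimal sketch of the kind of measurement described above, assuming each image is processed by a single call. The helper name total_runtime and the session/input names in the comments are illustrative, not taken from the original benchmark script.

```python
import time


def total_runtime(run_one, inputs):
    """Return the summed wall-clock time (seconds) of run_one over all inputs."""
    total = 0.0
    for item in inputs:
        start = time.perf_counter()
        run_one(item)
        total += time.perf_counter() - start
    return total


# Hypothetical usage (session, processor and the input lists are placeholders):
#   onnx_time = total_runtime(
#       lambda arr: session.run(None, {"image_tensor": arr}), numpy_images)
#   torch_time = total_runtime(lambda t: processor(t), tensor_images)
```

Running one or two warm-up calls before timing, and wrapping the torch calls in torch.no_grad(), would make the two numbers easier to compare.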

Torch preprocessing time: 0.21s
Torch classification on transformer time: 2.92s
Onnx preprocessing time: 6.29s
Onnx classification on transformer time: 1.08s

The onnx preprocessing step takes about 30x as long as the torch preprocessing step (6.29s vs 0.21s). I suspect the converted model contains operators that are slow in onnxruntime. Can you help me with that?
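One way to check which operator dominates is onnxruntime's built-in profiler. The sketch below enables it and sums kernel time per op type. The profile field names used here (cat, name ending in _kernel_time, dur, args.op_name) reflect the usual layout of the profile JSON and should be treated as assumptions to verify against an actual profile file; the function names are illustrative.

```python
import json
from collections import defaultdict


def summarize_profile(events):
    """Sum kernel time (microseconds) per op type, slowest first."""
    totals = defaultdict(int)
    for event in events:
        args = event.get("args") or {}
        is_kernel = event.get("name", "").endswith("_kernel_time")
        if event.get("cat") == "Node" and is_kernel:
            totals[args.get("op_name", "?")] += event.get("dur", 0)
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)


def profile_model(model_path, feeds, runs=10):
    """Run the model with profiling enabled and summarize the result."""
    import onnxruntime as ort  # imported lazily; only needed for real runs

    options = ort.SessionOptions()
    options.enable_profiling = True
    session = ort.InferenceSession(
        model_path, sess_options=options, providers=["CPUExecutionProvider"])
    for _ in range(runs):
        session.run(None, feeds)
    profile_path = session.end_profiling()  # returns the JSON file name
    with open(profile_path) as handle:
        return summarize_profile(json.load(handle))
```

Time spent inside subgraphs of control-flow nodes may be reported under the enclosing node, so a large total for one op type is a starting point for investigation rather than a final answer.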

To reproduce

Torch preprocessing model:

import torch
from torchvision import transforms


class PreprocessingModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Resize to 224x224 (bicubic), convert to float in [0, 1], then normalize
        self.transforms = transforms.Compose([
            transforms.Resize(size=(224, 224),
                              interpolation=transforms.InterpolationMode.BICUBIC,
                              max_size=None,
                              antialias=None),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=torch.Tensor([0.5000, 0.5000, 0.5000]),
                                 std=torch.Tensor([0.5000, 0.5000, 0.5000]))
        ])

    def forward(self, x):
        # Add the batch dimension: CHW -> 1CHW
        return self.transforms(x).unsqueeze(0)

Conversion to onnx for preprocessing model:

def export_preprocessing_to_onnx(self):
    processor = PreprocessingModel()
    images = glob.glob(f'{dataset_dir}/test/*/*.jpg')

    trace_input = Image.open(images[0])
    trace_input = transforms.PILToTensor()(trace_input)

    output = processor(trace_input)

    torch.onnx.export(processor,
                      trace_input,
                      processor_path,
                      export_params=True,
                      opset_version=16,
                      do_constant_folding=True,
                      input_names=['image_tensor'],
                      output_names=['preprocessed_image_tensor'],
                      dynamic_axes={'image_tensor': {1: 'image_height', 2: 'image_width'}})

Attaching onnx preprocessing model
processor.zip

Urgency

Runtime optimization for models is important for successful demonstration with clients

Platform

Windows

OS Version

22H2

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.3

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

Model File

Added in the "To reproduce" section

Is this a quantized model?

No

github-actions bot added the model:transformer and platform:windows labels Jan 30, 2024
yufenglee added the core runtime label and removed the model:transformer and platform:windows labels Jan 31, 2024
natke added the converter label Jan 31, 2024
@thiagocrepaldi
Contributor

thiagocrepaldi commented Jan 31, 2024

One of the reasons for the performance penalty might be the presence of an If (if-else) op in the graph

image

By default, torch.onnx.export traces the graph and inlines conditional nodes, representing only the branch that was executed during tracing. Was this model exported with torch.jit.script at any point?

Exporting the model with verbose=True will add a doc_string field to each exported node, containing the stack trace showing where that op came from
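A small sketch for checking an exported model for control-flow nodes and reading back where they came from. It relies only on the standard ONNX node fields op_type, name and doc_string; the helper names are made up for illustration, and doc_string is only populated when the model was exported with verbose=True.

```python
from collections import Counter

CONTROL_FLOW_OPS = ("If", "Loop", "Scan")


def op_histogram(graph):
    """Count the top-level nodes of a graph per op type."""
    return Counter(node.op_type for node in graph.node)


def control_flow_nodes(graph):
    """Return (name, doc_string) for each top-level control-flow node."""
    return [(node.name, node.doc_string)
            for node in graph.node
            if node.op_type in CONTROL_FLOW_OPS]


def inspect_model(path):
    """Load an ONNX file and report its op counts and control-flow nodes."""
    import onnx  # imported lazily; only needed when loading a real model

    model = onnx.load(path)
    return op_histogram(model.graph), control_flow_nodes(model.graph)
```

Calling inspect_model on the exported preprocessing model would show whether an If node is present and, through doc_string, which source line produced it.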

@arseniymerkulov
Author

arseniymerkulov commented Feb 1, 2024

The model was exported only with torch.onnx.export, from the native torch model class.
This is the output with verbose=True:

Exported graph: graph(%image_tensor : Byte(3, *, *, strides=[1, 417, 3], requires_grad=0, device=cpu),
      %onnx::Concat_40 : Long(2, strides=[1], requires_grad=0, device=cpu),
      %onnx::Clip_41 : Float(requires_grad=0, device=cpu),
      %onnx::Clip_42 : Float(requires_grad=0, device=cpu)):
  %onnx::Unsqueeze_1 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="Constant_0"]() # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:549:0
  %onnx::Cast_2 : Byte(1, 3, *, *, strides=[3, 1, 417, 3], requires_grad=0, device=cpu) = onnx::Unsqueeze[onnx_name="Unsqueeze_1"](%image_tensor, %onnx::Unsqueeze_1) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:549:0
  %img : Float(1, 3, *, *, strides=[3, 1, 417, 3], requires_grad=0, device=cpu) = onnx::Cast[to=1, onnx_name="Cast_2"](%onnx::Cast_2) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:557:0
  %onnx::Slice_5 : Long(4, strides=[1], device=cpu) = onnx::Shape[onnx_name="Shape_3"](%img) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torch\nn\functional.py:3946:0
  %onnx::Slice_6 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="Constant_4"]() # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torch\nn\functional.py:3946:0
  %onnx::Slice_7 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="Constant_5"]() # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torch\nn\functional.py:3946:0
  %onnx::Slice_8 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="Constant_6"]() # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torch\nn\functional.py:3946:0
  %onnx::Concat_9 : Long(2, strides=[1], device=cpu) = onnx::Slice[onnx_name="Slice_7"](%onnx::Slice_5, %onnx::Slice_7, %onnx::Slice_8, %onnx::Slice_6) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torch\nn\functional.py:3946:0
  %onnx::Resize_11 : Long(4, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name="Concat_8"](%onnx::Concat_9, %onnx::Concat_40) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torch\nn\functional.py:3946:0
  %onnx::Resize_12 : Tensor? = prim::Constant() # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torch\nn\functional.py:3946:0
  %onnx::Resize_13 : Tensor? = prim::Constant() # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torch\nn\functional.py:3946:0
  %onnx::Clip_14 : Float(*, *, *, *, strides=[150528, 50176, 224, 1], requires_grad=0, device=cpu) = onnx::Resize[coordinate_transformation_mode="pytorch_half_pixel", cubic_coeff_a=-0.75, mode="cubic", nearest_mode="floor", onnx_name="Resize_9"](%img, %onnx::Resize_12, %onnx::Resize_13, %onnx::Resize_11) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torch\nn\functional.py:3946:0
  %img.4 : Float(*, *, *, *, strides=[150528, 50176, 224, 1], requires_grad=0, device=cpu) = onnx::Clip[onnx_name="Clip_10"](%onnx::Clip_14, %onnx::Clip_41, %onnx::Clip_42) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:499:0
  %onnx::Gather_20 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="Constant_11"]() # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:563:0
  %onnx::Gather_21 : Long(4, strides=[1], device=cpu) = onnx::Shape[onnx_name="Shape_12"](%img.4) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:563:0
  %onnx::Equal_22 : Long(1, strides=[1], device=cpu) = onnx::Gather[axis=0, onnx_name="Gather_13"](%onnx::Gather_21, %onnx::Gather_20) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:563:0
  %onnx::Equal_23 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={1}, onnx_name="Constant_14"]() # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:563:0
  %onnx::If_24 : Bool(1, strides=[1], device=cpu) = onnx::Equal[onnx_name="Equal_15"](%onnx::Equal_22, %onnx::Equal_23) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:563:0
  %onnx::Round_25 : Float(*, *, *, device=cpu) = onnx::If[onnx_name="If_16"](%onnx::If_24) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:563:0
    block0():
      %26 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="Constant_17"]()
      %27 : Float(*, *, *, device=cpu) = onnx::Squeeze[onnx_name="Squeeze_18"](%img.4, %26)
      -> (%27)
    block1():
      %28 : Float(*, *, *, *, device=cpu) = onnx::Identity[onnx_name="Identity_19"](%img.4)
      -> (%28)
  %onnx::Cast_29 : Float(*, *, *, strides=[50176, 224, 1], requires_grad=0, device=cpu) = onnx::Round[onnx_name="Round_20"](%onnx::Round_25) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:568:0
  %image : Byte(*, *, *, strides=[50176, 224, 1], requires_grad=0, device=cpu) = onnx::Cast[to=2, onnx_name="Cast_21"](%onnx::Cast_29) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:569:0
  %onnx::Div_31 : Float(*, *, *, strides=[50176, 224, 1], requires_grad=0, device=cpu) = onnx::Cast[to=1, onnx_name="Cast_22"](%image) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:100:0
  %onnx::Div_32 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={255}, onnx_name="Constant_23"]()
  %tensor : Float(*, *, *, strides=[50176, 224, 1], requires_grad=0, device=cpu) = onnx::Div[onnx_name="Div_24"](%onnx::Div_31, %onnx::Div_32) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:101:0
  %onnx::Sub_34 : Float(3, 1, 1, strides=[1, 1, 1], requires_grad=0, device=cpu) = onnx::Constant[value=(1,.,.) =    0.5000  (2,.,.) =    0.5000  (3,.,.) =    0.5000 [ CPUFloatType{3,1,1} ], onnx_name="Constant_25"]()
  %onnx::Div_35 : Float(*, *, *, strides=[50176, 224, 1], requires_grad=0, device=cpu) = onnx::Sub[onnx_name="Sub_26"](%tensor, %onnx::Sub_34) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:959:0
  %onnx::Div_36 : Float(3, 1, 1, strides=[1, 1, 1], requires_grad=0, device=cpu) = onnx::Constant[value=(1,.,.) =    0.5000  (2,.,.) =    0.5000  (3,.,.) =    0.5000 [ CPUFloatType{3,1,1} ], onnx_name="Constant_27"]()
  %onnx::Unsqueeze_37 : Float(*, *, *, strides=[50176, 224, 1], requires_grad=0, device=cpu) = onnx::Div[onnx_name="Div_28"](%onnx::Div_35, %onnx::Div_36) # C:\Users\Arsen\anaconda3\envs\ics-object-detection\lib\site-packages\torchvision\transforms\functional_tensor.py:959:0
  %onnx::Unsqueeze_38 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="Constant_29"]() # E:\wok\ics\ics-classification-train\export.py:46:0
  %preprocessed_image_tensor : Float(1, *, *, *, strides=[150528, 50176, 224, 1], requires_grad=0, device=cpu) = onnx::Unsqueeze[onnx_name="Unsqueeze_30"](%onnx::Unsqueeze_37, %onnx::Unsqueeze_38) # E:\wok\ics\ics-classification-train\export.py:46:0
  return (%preprocessed_image_tensor)

@arseniymerkulov
Author

The If op comes from the resize transform; removing it speeds execution up to 0.15s per 72 images. How can I optimize this while keeping the resize layer?
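In the verbose output above, the If node wraps a Squeeze/Identity pair at functional_tensor.py:563, which suggests the conditional comes from squeezing a dimension whose size is not known at export time. One possible workaround, sketched below under that assumption, is to write the resize with torch.nn.functional.interpolate and keep the batch dimension throughout, so no squeeze is needed. This has not been checked against the exported graph or the timings in this thread; build_preprocessor and ResizeNormalize are illustrative names.

```python
def build_preprocessor(size=(224, 224), mean=0.5, std=0.5):
    """Build a resize + normalize module that never squeezes a dimension."""
    # torch is imported lazily so this file can be loaded without it
    import torch
    import torch.nn.functional as F

    class ResizeNormalize(torch.nn.Module):
        def forward(self, x):
            # x: uint8 image tensor laid out as CHW
            x = x.unsqueeze(0).to(torch.float32)  # CHW -> 1CHW
            x = F.interpolate(x, size=size, mode="bicubic", align_corners=False)
            # Mimic the uint8 round trip: clamp, round, then scale to [0, 1]
            x = x.clamp(0, 255).round() / 255.0
            return (x - mean) / std  # output stays 1CHW

    return ResizeNormalize()
```

The module returned here could be passed to torch.onnx.export in place of PreprocessingModel; inspecting the result for If nodes would confirm whether the conditional is gone.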

@skottmckay
Contributor

skottmckay commented Feb 1, 2024

There are tools in the onnxruntime_extensions package to do most of what you want by directly editing the ONNX model.

If you were to export the base model, you could use something like this.

import onnx
from onnxruntime_extensions.tools.pre_post_processing import *

onnx_opset = 17  # use opset 18 for Resize to antialias

model_path = "pytorch.mobilenet_v2_float.onnx"
model = onnx.load(model_path)
inputs = [create_named_value("image_tensor", onnx.TensorProto.UINT8, [3, "h", "w"])]

pipeline = PrePostProcessor(inputs, onnx_opset)

pipeline.add_pre_processing(
    [
        Resize(224, layout="CHW"),  # Uses BILINEAR currently 
        ImageBytesToFloat(),  # Convert to float in range 0..1 by dividing uint8 values by 255
        Normalize([(0.5, 0.5), (0.5, 0.5), (0.5, 0.5)]),  # (mean, stddev) for each channel 
        Unsqueeze([0]),  # add batch, CHW --> 1CHW
    ]
)

new_model = pipeline.run(model)
output_path = model_path.replace(".onnx", ".withpreprocessing.onnx")
onnx.save_model(new_model, output_path)

One issue is that the Resize implementation in that tool currently defaults to bilinear, as we haven't had a use case that differed. That could be made configurable. You should be able to edit this line in the Python code in the package to change 'linear' to 'cubic'. Alternatively, add this hack to the script that adds the pre-processing, before saving the updated model.

# hack to change linear to cubic
for n in new_model.graph.node:
    if n.op_type == "Resize":
        for attr in n.attribute:
            if attr.name == "mode":
                attr.s = "cubic".encode("UTF-8")

You can also do conversion from jpg/png to bytes inside the ONNX model if you take a dependency on the onnxruntime_extensions package. That uses opencv to do the conversion. See the example usage info for an overview and these docs for info on individual pre/post processing steps.

@arseniymerkulov
Author

I checked that using bilinear interpolation in preprocessing does not change model accuracy at all, so this is not an issue. However, I failed to run the model after adding the ort-extensions pre-processing steps with your code. It looks OK to me in Netron, and I checked it with onnx.checker. The original model has fixed input/output shapes: [1, 3, 224, 224] -> [1, 28]

When I run the model with the code below, I get an error:

onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:'Add_107' Status Message: D:\a\_work\1\s\onnxruntime\core/providers/cpu/math/element_wise_ops.h:560 onnxruntime::BroadcastIterator::Append axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 197 by 477

Code:

  image = Image.open("...")
  inputs = transforms.PILToTensor()(image)
  inputs = inputs.numpy()
  outputs = self.session.run(None, {'image': inputs})[0]  

In the code that modifies the model with ort-extensions I changed the input name from "image_tensor" to "image"; it seems the Resize layer has a fixed input name and the model input should match it. I also tried different policies in Resize.
Attaching the onnx model with preprocessing
model.fixed.withpreprocessing.onnx.zip
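One possible reading of the "197 by 477" in the error, offered as an assumption rather than a diagnosis: if the transformer uses 16x16 patches plus one class token (neither is stated in this thread), a 224x224 input gives 14 * 14 + 1 = 197 tokens, which matches the first number. The second number, 477, would then correspond to 14 * 34 + 1, that is, an image of roughly 224 x 544 reaching the transformer, as would happen if the resize fixed only one side at 224 and kept the aspect ratio. The arithmetic can be sketched as:

```python
def vit_token_count(height, width, patch=16, class_tokens=1):
    """Number of tokens a ViT sees for an image of the given size.

    patch=16 and class_tokens=1 are assumptions about the model,
    not values taken from the issue.
    """
    return (height // patch) * (width // patch) + class_tokens


print(vit_token_count(224, 224))  # expected square input
print(vit_token_count(224, 544))  # one side resized to 224, aspect ratio kept
```

If that reading is right, printing the shape of the tensor that reaches the original model input would show a non-square image, and a center-crop step after the resize might be what is missing.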

@arseniymerkulov
Author

Created an issue in onnxruntime-extensions for this, as it is more specific to that project

sophies927 added the performance label Feb 22, 2024