
[Performance] NnapiExecutionProvider defers all nodes to CPUExecutionProvider #18571

Closed
StefanWenninger opened this issue Nov 23, 2023 · 6 comments
Labels
platform:mobile (issues related to ONNX Runtime mobile; typically submitted using template)
quantization (issues related to quantization)

Comments

@StefanWenninger

Describe the issue

Goal

We are trying to utilize the NPU of our NXP i.MX8MP SoC. Our goal is to accelerate several ONNX models that previously ran on the CPU. To that end, we want to establish a pipeline that converts our non-quantized, FP32, PyTorch-based ONNX models into models we can run on the NPU of our SoC using the NnapiExecutionProvider of onnxruntime.

Setup

  • OS: Linux Yocto Mickledore 4.2
  • onnxruntime: onnxruntime-imx 1.13.1 from nxp-imx via Yocto recipe
  • SoC: NXP i.MX8MP

Issue

Development of the pipeline is based on an example from Microsoft. The example first exports a PyTorch-trained MobileNet V2 model to a non-quantized, FP32 ONNX model and then quantizes that model to UInt8.

I tried to follow this example with the exact same MobileNet V2 model to get a UInt8-quantized model that can run efficiently on the NPU. However, the resulting model runs exactly as fast on the NPU as on the CPU. When I use a MobileNet V2 model from the internet that was already UInt8-quantized, I do observe a significant performance improvement over the CPU.

My question therefore is: Am I doing something wrong in my quantization/conversion pipeline?

Comparison of the two models

Reference: mobilenet_v2_uint8.v5.ort

  • Source: Downloaded uint8 ORT Model from Mobile image recognition on Android
  • Performance: CPU inference: 61ms | NPU inference: 22ms
  • Quantization: reportedly UInt8 quantized via mobilenet.ipynb
  • Representation: operator-oriented format (QOperator)
  • Node placement: onnxruntime places almost all nodes on the NnapiExecutionProvider. Only one node (QLinearGlobalAveragePool) gets placed on the CPUExecutionProvider.

Our model: mobilenet_v2_infer_uint8_op_oriented.onnx

  • Source: Created by us via prototype pipeline
  • Performance: CPU inference: 54ms | NPU inference: 54ms
  • Quantization: UInt8 quantized via mobilenet.ipynb (see attached quantize_float_to_uint8.py)
    • changed parameter opset_version of torch.onnx.export from 12 to 17
    • changed parameter quant_format of quantize_static from QuantFormat.QDQ to QuantFormat.QOperator
  • Representation: operator-oriented format (QOperator)
  • Node placement: All nodes are placed on the CPUExecutionProvider

Observations

  • The identical performance of our model on CPU and NPU makes sense, since the NnapiExecutionProvider defers all nodes to the CPUExecutionProvider.
  • The NPU execution log of the reference model notes the existence of 2 graph partitions. Additionally 2 Nnapi_... nodes are shown to be placed on the NnapiExecutionProvider. As far as I know the NnapiExecutionProvider compiles the model for execution on the NPU. I therefore assume the two Nnapi_... nodes are the two compiled graph partitions.
  • Comparison of the models in netron.app:
    • Basic structure is identical
    • Used nodes are mostly identical
    • Tensor sizes are exactly identical
    • Some parameter values are different
      • zero_point in our model is often very different (sometimes even negative)
    • At the end of our model some nodes differ from the reference model:
      • Flatten instead of Reshape
      • QGemm instead of QLinearMatMul and QLinearAdd

Attachments

Models

Contains mobilenet_v2_uint8.v5.ort and mobilenet_v2_infer_uint8_op_oriented.onnx.
mobilenet_v2_models.zip
mobilenet_v2_uint8.v5.ort.png
mobilenet_v2_infer_uint8_op_oriented.onnx.png

Execution logs

exec_log_our_model_cpu.txt
exec_log_our_model_npu.txt
exec_log_reference_model_cpu.txt
exec_log_reference_model_npu.txt

Python scripts

Contains quantize_float_to_uint8.py and run_onnx_sample.py.
python_scripts.zip

To reproduce

  1. Execute python3 quantize_float_to_uint8.py:
# This script combines all snippets from https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/notebooks/imagenet_v2/mobilenet.ipynb
from torchvision import models, datasets, transforms as T
mobilenet_v2 = models.mobilenet_v2(pretrained=True)


import torch
image_height = 224
image_width = 224
x = torch.randn(1, 3, image_height, image_width, requires_grad=True)
torch_out = mobilenet_v2(x)
torch.onnx.export(mobilenet_v2, x, "mobilenet_v2_float.onnx", export_params=True, opset_version=17, do_constant_folding=True, input_names=['input'], output_names=['output'])

# pre-process (python3 -m onnxruntime.quantization.preprocess --input mobilenet_v2_float.onnx --output mobilenet_v2_infer_float.onnx)
from onnxruntime.quantization.shape_inference import quant_pre_process
quant_pre_process('mobilenet_v2_float.onnx', 'mobilenet_v2_infer_float.onnx')

from PIL import Image
import numpy as np
import onnxruntime
import torch

def preprocess_image(image_path, height, width, channels=3):
    image = Image.open(image_path)
    image = image.resize((width, height), Image.LANCZOS)
    image_data = np.asarray(image).astype(np.float32)
    image_data = image_data.transpose([2, 0, 1])
    mean = np.array([0.079, 0.05, 0]) + 0.406
    std = np.array([0.005, 0, 0.001]) + 0.224
    for channel in range(image_data.shape[0]):
        image_data[channel, :, :] = (image_data[channel, :, :] / 255 - mean[channel]) / std[channel]
    image_data = np.expand_dims(image_data, 0)
    return image_data


with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]


session_fp32 = onnxruntime.InferenceSession("mobilenet_v2_infer_float.onnx", providers=['CPUExecutionProvider'])

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

def run_sample(session, image_file, categories):
    output = session.run([], {'input': preprocess_image(image_file, image_height, image_width)})[0]
    output = output.flatten()
    output = softmax(output)
    top5_catid = np.argsort(-output)[:5]
    for catid in top5_catid:
        print(categories[catid], output[catid])

run_sample(session_fp32, 'cat.jpg', categories)


from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantFormat
import os

def preprocess_func(images_folder, height, width, size_limit=0):
    image_names = os.listdir(images_folder)
    if size_limit > 0 and len(image_names) >= size_limit:
        batch_filenames = [image_names[i] for i in range(size_limit)]
    else:
        batch_filenames = image_names
    unconcatenated_batch_data = []
    for image_name in batch_filenames:
        image_filepath = images_folder + '/' + image_name
        image_data = preprocess_image(image_filepath, height, width)
        print(f'shape of pre-processed image_data: {image_data.shape}')
        unconcatenated_batch_data.append(image_data)
    batch_data = np.concatenate(np.expand_dims(unconcatenated_batch_data, axis=0), axis=0)
    return batch_data

class MobilenetDataReader(CalibrationDataReader):
    def __init__(self, calibration_image_folder):
        self.image_folder = calibration_image_folder
        self.preprocess_flag = True
        self.enum_data_dicts = None
        self.datasize = 0
    def get_next(self):
        if self.preprocess_flag:
            self.preprocess_flag = False
            nhwc_data_list = preprocess_func(self.image_folder, image_height, image_width, size_limit=0)
            print(f'type nhwc_data_list: {type(nhwc_data_list)}')
            print(f'nhwc_data_list shape: {nhwc_data_list.shape}')
            self.datasize = len(nhwc_data_list)
            list_for_iterator = [{'input': nhwc_data} for nhwc_data in nhwc_data_list]
            self.enum_data_dicts = iter(list_for_iterator)
        return next(self.enum_data_dicts, None)


calibration_data_folder = "calibration_imagenet"
dr = MobilenetDataReader(calibration_data_folder)
quantize_static(model_input='mobilenet_v2_infer_float.onnx',
                model_output='mobilenet_v2_infer_uint8_op_oriented.onnx',
                calibration_data_reader=dr,
                quant_format=QuantFormat.QOperator)

print(f'ONNX full precision model size (MB): {os.path.getsize("mobilenet_v2_infer_float.onnx")/(1024*1024)}')
print(f'ONNX quantized model size (MB): {os.path.getsize("mobilenet_v2_infer_uint8_op_oriented.onnx")/(1024*1024)}')
  2. Execute python3 run_onnx_sample.py --model mobilenet_v2_infer_uint8_op_oriented.onnx --image cat.jpg --pu npu:
import onnxruntime
import numpy as np
from PIL import Image
import time
import math
import argparse


image_height = 224
image_width = 224

def preprocess_image(image_path, height, width, channels=3):
    image = Image.open(image_path)
    image = image.resize((width, height), Image.LANCZOS)
    image_data = np.asarray(image).astype(np.float32)
    image_data = image_data.transpose([2, 0, 1])
    mean = np.array([0.079, 0.05, 0]) + 0.406
    std = np.array([0.005, 0, 0.001]) + 0.224
    for channel in range(image_data.shape[0]):
        image_data[channel, :, :] = (image_data[channel, :, :] / 255 - mean[channel]) / std[channel]
    image_data = np.expand_dims(image_data, 0)
    return image_data

with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

def run_sample(session, image_file, categories):
    start_run = time.time()
    output = session.run([], {'input':preprocess_image(image_file, image_height, image_width)})[0]
    end_run = time.time()
    output = output.flatten()
    output = softmax(output) # this is optional
    top5_catid = np.argsort(-output)[:5]
    for catid in top5_catid:
        print(categories[catid], output[catid])
    print(f'#################### execution time: {math.floor((end_run-start_run)*1000)}ms')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='Model to be inferred')
    parser.add_argument('--image', help='Image to run inference on')
    parser.add_argument('--pu', help='Processing unit to run inference with')
    args = parser.parse_args()

    opts = onnxruntime.SessionOptions()
    opts.enable_profiling = True
    opts.log_severity_level = 0
    opts.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
    opts.intra_op_num_threads = 3
    opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    if args.model.split('.')[-1] == 'ort':
        opts.add_session_config_entry('session.load_model_format', 'ORT')
    else:
        opts.add_session_config_entry('session.load_model_format', 'ONNX')

    if args.pu == 'cpu':
        session_quant = onnxruntime.InferenceSession(args.model, sess_options=opts, providers=['CPUExecutionProvider'])
    elif args.pu == 'npu':
        session_quant = onnxruntime.InferenceSession(args.model, sess_options=opts, providers=['NnapiExecutionProvider'])

    print(f'#################### warmup')
    run_sample(session_quant, args.image, categories)
    print(f'#################### run')
    run_sample(session_quant, args.image, categories)
    print(f'#################### run')
    run_sample(session_quant, args.image, categories)
    print(f'#################### run')
    run_sample(session_quant, args.image, categories)

Necessary scripts and files are in reproduce.zip

Urgency

No response

Platform

Linux

OS Version

Linux Yocto Mickledore 4.2

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.13.1

ONNX Runtime API

Python

Architecture

ARM64

Execution Provider

NNAPI

Execution Provider Library Version

No response

Model File

mobilenet_v2_models.zip

Is this a quantized model?

Yes

@github-actions github-actions bot added the platform:mobile and quantization labels Nov 23, 2023
@skottmckay
Contributor

QuantFormat.QOperator uses custom operators that are not implemented by all execution providers.

QDQ format is more generic as it uses official ONNX operators wrapped in DQ/Q nodes, which allows an EP like the NNAPI EP to convert those node units to the quantized equivalent.

Is there a reason not to use QDQ format? I believe the CPU EP should be able to convert the QDQ node units into the equivalent QOperator at runtime.

@StefanWenninger
Author

Since the reference model from the web does seem to benefit from NPU acceleration, I tried to align our model with the reference, and the reference model happens to be in QOperator format.
Ours was previously in QDQ format. When trying to execute that model on the NnapiExecutionProvider, onnxruntime tells us nearly all nodes are placed on the CPUExecutionProvider (see exec_log_our_model_npu_qdq.txt).

So really I am just trying to reduce the number of differences between our model and the reference in hopes of getting our model to run efficiently on the NPU.

From my understanding, the custom operators introduced by the QOperator format should be implemented for the NnapiExecutionProvider on my platform, since the reference model does run on the NPU with mostly the same custom operators. It could of course be that the two operators that appear only in our model (Flatten and QGemm) are not implemented in the NnapiExecutionProvider. But even then I would expect at least some of the remaining operators to be placed on the NnapiExecutionProvider. From what I can tell from the logs, onnxruntime does not even try to place any nodes on the NnapiExecutionProvider.

I have also tried converting our .onnx model into an .ort model. Unfortunately, the model did not show any improvement, neither with --optimization_level fixed nor with --optimization_level runtime.
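
For completeness, a rough sketch of that conversion step, assuming the onnxruntime ORT-format conversion script; the exact flag name and values are assumptions on my part and may differ in the onnxruntime-imx 1.13.1 build (check the script's --help on the target):

# Assumed tool and flag names -- verify with:
#   python3 -m onnxruntime.tools.convert_onnx_models_to_ort --help
python3 -m onnxruntime.tools.convert_onnx_models_to_ort \
    mobilenet_v2_infer_uint8_op_oriented.onnx \
    --optimization_style Fixed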

I also stumbled across a PyTorch tutorial detailing how to prepare a PyTorch model for NNAPI execution. Could something like this be necessary?

@edgchen1
Contributor

I was able to reproduce this with the mobilenet_v2_infer_uint8_op_oriented.onnx model generated by quantize_float_to_uint8.py.

FYI, there is more log output from the NNAPI EP that goes to the default logger. From Python, you can enable this with onnxruntime.set_default_logger_severity(0).
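
For anyone following along, a minimal sketch of enabling that verbose default logger before creating the session (the model path here is just a placeholder):

import onnxruntime

# 0 = VERBOSE. This controls the default logger that the NNAPI EP writes to,
# which is separate from SessionOptions.log_severity_level (the per-session logger).
onnxruntime.set_default_logger_severity(0)

session = onnxruntime.InferenceSession(
    "mobilenet_v2_infer_uint8_op_oriented.onnx",  # placeholder model path
    providers=["NnapiExecutionProvider"],
)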

With the default logger's verbose output, we can see some info about why the ops are not supported:

11-27 16:51:41.486  2119  2119 V onnxruntime:  [V:onnxruntime:Default, helper.cc:177 HasValidBinaryOpQuantizedInputTypes] [QLinearConv] A Input type: [3] B Input type: [3] is not supported for now
11-27 16:51:41.486  2119  2119 V onnxruntime:  [V:onnxruntime:Default, nnapi_execution_provider.cc:151 operator()] Node supported: [0] Operator type: [QLinearConv] index: [1] name: [/features/features.0/features.0.0/Conv_quant] as part of the NodeUnit type: [QLinearConv] index: [1] name: [/features/features.0/features.0.0/Conv_quant]

Input type 3 corresponds to int8.
https://github.com/onnx/onnx/blob/5be7f3164ba0b2c323813264ceb0ae7e929d2350/onnx/onnx.in.proto#L490

Looks like uint8 would be supported.

// QlinearConv/MatMul/QDQGemm/QDQMatMul supports u8u8 or u8s8
// QLinearAdd/QLinearMul only support u8u8
bool is_quant_conv_or_gemm = IsQuantizedConv(quant_op_type) || IsQuantizedGemm(quant_op_type);
bool has_valid_qlinear_conv_weight =
    (b_input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 ||
     b_input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8);
if (a_input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8 ||
    (!is_quant_conv_or_gemm && a_input_type != b_input_type) ||
    (is_quant_conv_or_gemm && !has_valid_qlinear_conv_weight)) {
  LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType()
                        << "] A Input type: [" << a_input_type
                        << "] B Input type: [" << b_input_type
                        << "] is not supported for now";
  return false;
}

The mobilenet_v2_uint8.v5.ort model does have uint8 inputs for the first QLinearConv.
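
A quick way to check this on the generated .onnx model (not the .ort file) is to look at the element types of the first QLinearConv's inputs with the onnx package. This is just an illustrative sketch; it assumes the standard QLinearConv input order (x, x_scale, x_zero_point, w, ...) and uses the zero_point dtype as a proxy for the activation dtype:

import onnx

model = onnx.load("mobilenet_v2_infer_uint8_op_oriented.onnx")
# Element types of all initializers (weights, scales, zero points), keyed by name.
init_types = {init.name: init.data_type for init in model.graph.initializer}

for node in model.graph.node:
    if node.op_type == "QLinearConv":
        # QLinearConv inputs: x, x_scale, x_zero_point, w, w_scale, w_zero_point, ...
        a_type = init_types.get(node.input[2], 0)  # x_zero_point dtype mirrors the activation dtype; 0 = UNDEFINED
        b_type = init_types.get(node.input[3], 0)  # weight tensor dtype
        name = onnx.TensorProto.DataType.Name
        print(node.name, "A:", name(a_type), "B:", name(b_type))
        break  # the first QLinearConv is enough to see int8 vs uint8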

@StefanWenninger
Author

You struck gold with that observation!

First of all, thanks for the hint about onnxruntime.set_default_logger_severity(0). I did not know about this logging option; its output has already proven useful.

You were right in assuming UInt8 might solve the issue. I discovered that the main function used for quantization does allow setting the activation and weight type:

def quantize_static(
    model_input,
    model_output,
    calibration_data_reader: CalibrationDataReader,
    quant_format=QuantFormat.QDQ,
    op_types_to_quantize=None,
    per_channel=False,
    reduce_range=False,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
    nodes_to_quantize=None,
    nodes_to_exclude=None,
    optimize_model=True,
    use_external_data_format=False,
    calibrate_method=CalibrationMethod.MinMax,
    extra_options=None,
):

QuantType offers the option of QUInt8:

class QuantType(Enum):
    QInt8 = 0
    QUInt8 = 1

Setting those two parameters to QuantType.QUInt8 produced a .onnx model that is of UInt8 type:

quantize_static(model_input='mobilenet_v2_infer_float.onnx',
                model_output='mobilenet_v2_infer_uint8_op_oriented.onnx',
                calibration_data_reader=dr,
                quant_format=QuantFormat.QOperator,
                activation_type=QuantType.QUInt8,
                weight_type=QuantType.QUInt8)

With that true UInt8 model I achieve CPU inference in 62ms and NPU inference in 23ms.
A few observations about our new model mobilenet_v2_infer_uint8_op_oriented.zip:

  • The negative zero_point parameters in netron.app are gone. These values are now identical or at least similar to the reference model.
  • With our new model, 3 graph partitions are placed on the NnapiExecutionProvider as compiled Nnapi_... nodes. QLinearGlobalAveragePool is placed on the CPUExecutionProvider (as it was for the reference model). Additionally, QGemm is placed on the CPUExecutionProvider; I assume that node does not have an NNAPI implementation.

I should now have a working pipeline to convert our existing PyTorch-based .onnx models into UInt8 models that are ready to run on the NnapiExecutionProvider. I will get to testing that ASAP!

Thank you for pointing me in the right direction to solve this issue and thanks to everyone for your helpful insights!
For now I consider this issue closed.

@edgchen1
Contributor

Great! I'll go ahead and close this issue. Feel free to open a new one if you have other questions.

I didn't find docs for onnxruntime.set_default_logger_severity() and other module-level functions; we should probably add some.

@edgchen1
Contributor

And yes, from a quick look at the code I believe QGemm and QLinearGlobalAveragePool are not currently supported by the NNAPI EP. That said, there may be more support (at least for Gemm) in a QDQ format model, so that may be worth trying too.
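
For reference, a minimal sketch of what that QDQ variant might look like, reusing the quantize_static call from quantize_float_to_uint8.py with only quant_format changed (and keeping the uint8 activation/weight types that resolved the earlier issue). This is an untested suggestion, not something verified on the i.MX8MP, and the output filename is hypothetical:

from onnxruntime.quantization import quantize_static, QuantFormat, QuantType

# Same calibration data reader (dr) as in quantize_float_to_uint8.py.
quantize_static(model_input='mobilenet_v2_infer_float.onnx',
                model_output='mobilenet_v2_infer_uint8_qdq.onnx',  # hypothetical output name
                calibration_data_reader=dr,
                quant_format=QuantFormat.QDQ,       # QDQ instead of QOperator
                activation_type=QuantType.QUInt8,   # keep uint8 activations for NNAPI
                weight_type=QuantType.QUInt8)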
