
[ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running PRelu node #3205

Closed
cannguyen275 opened this issue Mar 13, 2020 · 7 comments

Comments

@cannguyen275

cannguyen275 commented Mar 13, 2020

Describe the bug
I'm trying to run a model that was converted from MXNet to ONNX, and I get the following error:

Traceback (most recent call last):
  File "/home/my_project/recognize_onnx.py", line 31, in <module>
    out = ort_session.run([outputs], input_feed={input_name: input_blob})
  File "/home/miniconda3/envs/onnx13/lib/python3.7/site-packages/onnxruntime/capi/session.py", line 142, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running PRelu node. Name:'conv_1_relu' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:339 void onnxruntime::BroadcastIterator::Init(int64_t, int64_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 56 by 64

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • ONNX Runtime installed from (source or binary): binary
  • ONNX Runtime version: 1.1.2
  • Python version: 3.7

To Reproduce
You can get my model at this link.

import cv2
import numpy as np
import onnx
import onnxruntime as ort

onnx_path = "mymodel.onnx"

# The checker passes without complaint
predictor = onnx.load(onnx_path)
print(onnx.checker.check_model(predictor))

ort_session = ort.InferenceSession(onnx_path)
input_name = ort_session.get_inputs()[0].name
outputs = ort_session.get_outputs()[0].name

# Preprocess: BGR -> RGB, resize to 112x112, HWC -> CHW, add batch dim
img = cv2.imread("myImage.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (112, 112))
aligned = np.transpose(img, (2, 0, 1))
input_blob = np.expand_dims(aligned, axis=0).astype(np.float32)

out = ort_session.run([outputs], input_feed={input_name: input_blob})

How can I fix this error? Any advice would be awesome!
Thanks

@hariharans29
Member

hariharans29 commented Mar 13, 2020

Well the root cause is that there seems to have been a problem with the export to ONNX -

(image: visualization of the exported graph, with the problematic Conv node marked)

The output shape of the marked Conv node will be [1, 64, 56, 56], and BatchNorm will propagate that shape to PRelu. The slope tensor in PRelu must be broadcastable to the input shape (https://github.com/onnx/onnx/blob/master/docs/Operators.md#prelu), and [64] is not broadcastable to [1, 64, 56, 56].
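As a sanity check, the failing broadcast can be reproduced with plain NumPy, whose broadcasting aligns trailing axes the same way ONNX broadcasting does; the shapes below are taken from the error message:

```python
import numpy as np

x_shape = (1, 64, 56, 56)  # PRelu input propagated from BatchNorm

# Slope exported as [64]: broadcasting aligns trailing axes, so 64 is
# matched against the last spatial dimension (56) and fails -- this is
# the "56 by 64" in the error message.
try:
    np.broadcast_shapes(x_shape, (64,))
except ValueError as err:
    print("not broadcastable:", err)

# Reshaping the slope to [1, 64, 1, 1] puts 64 on the channel axis,
# and broadcasting succeeds.
print(np.broadcast_shapes(x_shape, (1, 64, 1, 1)))
```

This is why the workarounds below all amount to reshaping the PRelu slope to [1, C, 1, 1] before (or during) export.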

Even if this issue is fixed, the model exported from MxNet with BatchNorm will not run. Please see the following for details -
onnx/models#156
apache/mxnet#17711

Closing as this is not a runtime issue.

@ilovewangzeyu

Hello. Have you solved your problem?

@cannguyen275
Author

@ilovewangzeyu I have not.
MXNet doesn't seem to get along well with ONNX, so I decided to use the PyTorch version of my model instead. That works!

@duonglong289

The root cause is as @hariharans29 said. I found this link, which changes the function that converts PRelu from MXNet to ONNX. It fixes this bug in the MXNet converter.
After that, the model exported from MXNet with BatchNorm might still not run because of "spatial=0" in BatchNormalization. Follow this link.

Alternatively, I wrote a script that post-processes the model exported from MXNet to ONNX, inserting a Reshape node in front of each PRelu so its slope broadcasts correctly, and it works for me.

import onnx
from onnx import checker
import logging

model = onnx.load(r"mxnet2onnx_exported_bug_model.onnx")
onnx_processed_nodes = []
onnx_processed_inputs = []
onnx_processed_outputs = []
onnx_processed_initializers = []

reshape_node = []

# Insert a Reshape in front of every PRelu so that its 1-D slope
# tensor can broadcast on the channel axis.
for ind, node in enumerate(model.graph.node):
    if node.op_type == "PRelu":
        input_node = node.input
        input_bn = input_node[0]          # activation coming from BatchNorm
        input_relu_gamma = input_node[1]  # slope tensor, shape [C]
        output_node = node.output[0]

        input_reshape_name = "reshape{}".format(ind)
        slope_number = "slope{}".format(ind)


        node_reshape = onnx.helper.make_node(
            op_type="Reshape",
            inputs=[input_relu_gamma, input_reshape_name],
            outputs=[slope_number],
            name=slope_number
        )

        reshape_node.append(input_reshape_name)
        node_relu = onnx.helper.make_node(
            op_type="PRelu",
            inputs=[input_bn, slope_number],
            outputs=[output_node],
            name=output_node
        )
        onnx_processed_nodes.extend([node_reshape, node_relu])

    
    else:
        # If "spatial = 0" does not work for "BatchNormalization", force
        # "spatial = 1"; otherwise this "if" block can be commented out.
        if node.op_type == "BatchNormalization":
            for attr in node.attribute:
                if (attr.name == "spatial"):
                    attr.i = 1
        onnx_processed_nodes.append(node)


# Each inserted Reshape needs its shape tensor [1, -1, 1, 1], registered
# both as a graph input and as an initializer.
list_new_inp = []
list_new_init = []
for name_rs in reshape_node:
    new_inp = onnx.helper.make_tensor_value_info(
        name=name_rs,
        elem_type=onnx.TensorProto.INT64,
        shape=[4]
    )
    new_init = onnx.helper.make_tensor(
        name=name_rs,
        data_type=onnx.TensorProto.INT64,
        dims=[4],
        vals=[1, -1, 1, 1]
    )
    
    list_new_inp.append(new_inp)
    list_new_init.append(new_init)
    

for k, inp in enumerate(model.graph.input):
    if "relu0_gamma" in inp.name or "relu1_gamma" in inp.name: #or "relu_gamma" in inp.name:
        new_reshape = list_new_inp.pop(0)
        onnx_processed_inputs.extend([inp, new_reshape])
    else:     
        onnx_processed_inputs.extend([inp])

for k, outp in enumerate(model.graph.output):
    onnx_processed_outputs.extend([outp])

for k, init in enumerate(model.graph.initializer):
    if "relu0_gamma" in init.name or "relu1_gamma" in init.name:
        new_reshape = list_new_init.pop(0)
        onnx_processed_initializers.extend([init, new_reshape])
    else:
        onnx_processed_initializers.extend([init])


graph = onnx.helper.make_graph(
        onnx_processed_nodes,
        "mxnet_converted_model",
        onnx_processed_inputs,
        onnx_processed_outputs
    )

graph.initializer.extend(onnx_processed_initializers)

# Check graph
checker.check_graph(graph)

onnx_model = onnx.helper.make_model(graph)

# Write model
str_input = '3,112,112'
input_shape = (1,) + tuple( [int(x) for x in str_input.split(',')] )
onnx_file_path = "mxnet2onnx_model_onnxruntime.onnx"

with open(onnx_file_path, "wb") as file_handle:
    serialized = onnx_model.SerializeToString()
    file_handle.write(serialized)
    logging.info("Input shape of the model %s ", input_shape)
    logging.info("Exported ONNX file %s saved to disk", onnx_file_path)

print("Done!!!")

@KokinSok

KokinSok commented Mar 8, 2024

Every angle I try comes back to an ONNX model error. Same error here. ONNX just doesn't work!

@imotai

imotai commented Mar 19, 2024

> Every angle, always come back to Onnx Model error. Same error here. Onnx just doesnt work!

Sad to end up here and still not know how to solve this problem.

@thewh1teagle

> Well the root cause is that there seems to have been a problem with the export to ONNX -

Any idea where I went wrong in the export here?
#21805
This is how I export the model: export-onnx.py
