Models with multiple outputs produce different results when the order of irrelevant lines are changed #18081

Open
Azyka opened this issue Oct 24, 2023 · 2 comments
Labels
core runtime (issues related to core runtime), ep:tvm (issues related to TVM execution provider), stale (issues that have not been addressed in a while; categorized by a bot)

Comments


Azyka commented Oct 24, 2023

Describe the issue

On opset version 14, when the nodes that feed multiple outputs are arranged in a different order, models that should be equivalent produce different outputs in onnxruntime. The result is correct when only one of the outputs is kept, but goes wrong when both outputs are present. The divergence was found when processing fp16 data, and it was not seen in TVM.
Generally, the bug seems to need several conditions:

  1. sufficiently large fp16 inputs
  2. multiple outputs
  3. specific operations
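The observation that a single output is correct while two outputs are not could be checked with a small helper like the one below. This is only a sketch: `compare_single_vs_joint` is a hypothetical name, and it assumes that requesting one output at a time from `InferenceSession.run` is a fair stand-in for keeping only one output, which is not necessarily the same as exporting a single-output model.

```python
import numpy as np

def compare_single_vs_joint(run, output_names, feeds):
    """Fetch all outputs in one call, then one at a time, and compare.

    `run` is any callable shaped like InferenceSession.run:
    run(list_of_output_names, feed_dict) -> list of arrays.
    Returns {output_name: True if the joint and single fetches agree exactly}.
    """
    joint = dict(zip(output_names, run(output_names, feeds)))
    agreement = {}
    for name in output_names:
        (single,) = run([name], feeds)  # request exactly one output
        agreement[name] = bool(
            np.array_equal(np.asarray(single), np.asarray(joint[name]))
        )
    return agreement
```

With the sessions from the reproducer below, a call would look like `compare_single_vs_joint(sess_0.run, output_names_0, input_dict_0)`.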

To reproduce

Test models in ort with and without optimization:

import onnxruntime as ort
import onnx
import numpy as np
from numpy import testing
import torch

class Model0(torch.nn.Module):
    def forward(self, *args):
        getitem = args[0]
        getitem_1 = args[1]
        div = torch.div(getitem_1, getitem)
        expand = div.expand(1, 2, 54)
        max_2 = torch.max(expand, getitem_1)
        gt = torch.gt(getitem_1, div)
        return (gt, max_2)

class Model1(torch.nn.Module):
    def forward(self, *args):
        getitem = args[0]
        getitem_1 = args[1]
        div = torch.div(getitem_1, getitem)
        gt = torch.gt(getitem_1, div)
        expand = div.expand(1, 2, 54)
        max_2 = torch.max(expand, getitem_1)
        return (gt, max_2)

model_0 = Model0()
input_data_0 = np.array([[3.62 ], [6.273]], dtype=np.float16)
input_data_1 = np.array([[3.617, 6.312, 5.45 , 6.28 , 5.363, 6.945, 6.03 , 4.82 , 5.438,
        4.21 , 3.969, 3.49 , 6.93 , 3.854, 6.652, 4.086, 6.33 , 4.336,
        5.246, 6.1  , 4.816, 5.76 , 6.637, 6.984, 5.83 , 5.242, 4.695,
        3.457, 4.273, 4.465, 5.617, 6.664, 3.53 , 6.12 , 3.74 , 6.57 ,
        5.516, 6.758, 6.71 , 6.902, 3.352, 4.44 , 4.008, 3.443, 3.803,
        4.844, 3.918, 3.645, 5.613, 4.36 , 5.02 , 3.766, 6.805, 4.312]], dtype=np.float16)
input_dict_0 = {'v5_0':input_data_0, 'v4_0':input_data_1}
inputs_0 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_0.items())
torch.onnx.export(model_0, inputs_0, '0.onnx', verbose=False, input_names=['v5_0', 'v4_0'], output_names=['v3_0', 'v1_0'], opset_version=16, do_constant_folding=False)

model_1 = Model1()
input_dict_1 = {'v0_0':input_data_0, 'v1_0':input_data_1}
inputs_1 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_1.items())
torch.onnx.export(model_1, inputs_1, '1.onnx', verbose=False, input_names=['v0_0', 'v1_0'], output_names=['v4_0', 'v7_0'], opset_version=16, do_constant_folding=False)


# Run in ort
output_names_0 = ['v1_0', 'v3_0']
sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_0 = ort.InferenceSession('0.onnx',providers=['CPUExecutionProvider'],sess_options=sess_options_0)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

output_names_1 = ['v4_0', 'v7_0']
sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_1 = ort.InferenceSession('1.onnx',providers=['CPUExecutionProvider'],sess_options=sess_options_1)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))
output_name_dict = {'v3_0': 'v4_0', 'v1_0': 'v7_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("onnxruntime_enable_opt does not trigger assertion")
except AssertionError as e:
    print("onnxruntime_enable_opt triggers assertion")
    print(e)
print('=========================')

sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_0 = ort.InferenceSession('0.onnx',providers=['CPUExecutionProvider'],sess_options=sess_options_0)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_1 = ort.InferenceSession('1.onnx',providers=['CPUExecutionProvider'],sess_options=sess_options_1)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("onnxruntime_disable_opt does not trigger assertion")
except AssertionError as e:
    print("onnxruntime_disable_opt triggers assertion")
    print(e)
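Before comparing outputs, it may help to confirm that 0.onnx and 1.onnx really contain the same operators and differ only in node order. The sketch below uses hypothetical helper names (`op_sequence`, `same_ops_ignoring_order`) and relies only on the `op_type` field of each node. It does not check how the nodes are wired together, so matching counts are a sanity check, not a proof of equivalence.

```python
from collections import Counter

def op_sequence(nodes):
    """List the op types in the order the nodes are stored."""
    return [n.op_type for n in nodes]

def same_ops_ignoring_order(nodes_a, nodes_b):
    """True if both node lists contain the same op types with the same counts.

    Accepts any iterable of objects exposing `.op_type`,
    for example onnx.load(path).graph.node.
    """
    return Counter(op_sequence(nodes_a)) == Counter(op_sequence(nodes_b))

# Possible usage with the exported files (requires the onnx package):
# import onnx
# nodes_0 = onnx.load('0.onnx').graph.node
# nodes_1 = onnx.load('1.onnx').graph.node
# print(op_sequence(nodes_0))
# print(op_sequence(nodes_1))
# print(same_ops_ignoring_order(nodes_0, nodes_1))
```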

I only changed the order of these lines in the models, from:

expand = div.expand(1, 2, 54)
max_2 = torch.max(expand, getitem_1)
gt = torch.gt(getitem_1, div)

to:

gt = torch.gt(getitem_1, div)
expand = div.expand(1, 2, 54)
max_2 = torch.max(expand, getitem_1)

and this should not change anything in execution.
However, I got the following output:

onnxruntime_enable_opt triggers assertion

Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 32 / 108 (29.6%)
 x: array([[False, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False, False,
         True,  True,  True, False,  True,  True,  True,  True,  True,...
 y: array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,...
=========================
=========================
onnxruntime_disable_opt triggers assertion

Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 26 / 108 (24.1%)
 x: array([[ True, False, False, False, False, False, False, False,  True,
         True, False, False, False, False, False, False,  True, False,
        False, False,  True, False, False, False,  True,  True,  True,...
 y: array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,...
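To judge which session output is the expected one, a NumPy reference for the shared forward pass can be useful. The sketch below is an assumption-laden stand-in, not part of the original report: `numpy_reference` is a hypothetical name, and it uses only the first nine values of `input_data_1` to stay short. Both divisors exceed 1 and all inputs are positive, so every element of `gt` should come out True. If that reasoning holds, the False entries in both printed arrays above would be unexpected.

```python
import numpy as np

def numpy_reference(getitem, getitem_1):
    """NumPy version of the forward pass shared by Model0 and Model1."""
    div = np.divide(getitem_1, getitem)              # broadcasts to (2, N), stays float16
    expand = np.broadcast_to(div, (1,) + div.shape)  # mirrors div.expand(1, 2, N)
    max_2 = np.maximum(expand, getitem_1)
    gt = np.greater(getitem_1, div)
    return gt, max_2

# Same divisors as input_data_0; values are a subset of input_data_1.
divisors = np.array([[3.62], [6.273]], dtype=np.float16)
values = np.array([[3.617, 6.312, 5.45, 6.28, 5.363,
                    6.945, 6.03, 4.82, 5.438]], dtype=np.float16)

gt, max_2 = numpy_reference(divisors, values)
print(gt.shape, max_2.shape, bool(gt.all()))
```

Passing the full `input_data_0` and `input_data_1` from the reproducer gives reference arrays that each ORT output can be compared against separately.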

Urgency

This is a functional correctness bug. It may cause severe bugs in systems built on top of ORT.

Platform

Linux

OS Version

Ubuntu 22.04.3 LTS (x86_64)

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@github-actions github-actions bot added the ep:tvm issues related to TVM execution provider label Oct 24, 2023

Azyka commented Oct 26, 2023

Here is a similar case that also contains the max operation. Maybe the bug lies in max/min?

import onnxruntime as ort
import onnx
import numpy as np
from numpy import testing
import tvm
from tvm import relay
import torch

class Model0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0]
        getitem_1 = _args[1];  _args = None
        ceil = torch.ceil(getitem)
        to = ceil.to(dtype = torch.int32)
        abs_1 = torch.abs(ceil)
        max_1 = torch.max(getitem_1, abs_1)
        return (to, max_1)

model_0 = Model0()
input_data_0 = np.array([5.766], dtype=np.float16)
input_data_1 = np.array([[[[6.566, 4.465, 6.43 ]]],
       [[[4.63 , 6.03 , 6.438]]],
       [[[6.797, 4.367, 6.387]]],
       [[[6.76 , 6.44 , 5.074]]],
       [[[4.3  , 5.902, 4.508]]],
       [[[3.248, 5.695, 6.33 ]]],
       [[[5.797, 4.73 , 5.2  ]]],
       [[[4.33 , 5.008, 6.555]]],
       [[[3.5  , 5.785, 5.82 ]]],
       [[[3.848, 4.887, 3.334]]],
       [[[3.717, 4.5  , 4.82 ]]],
       [[[5.363, 3.738, 6.176]]],
       [[[6.273, 4.883, 4.812]]],
       [[[6.695, 4.56 , 4.03 ]]],
       [[[4.473, 3.453, 4.46 ]]],
       [[[6.77 , 3.955, 6.086]]],
       [[[3.13 , 3.904, 5.32 ]]],
       [[[4.156, 5.105, 5.62 ]]],
       [[[4.09 , 4.43 , 4.812]]],
       [[[3.023, 6.56 , 5.94 ]]],
       [[[3.857, 6.996, 3.9  ]]],
       [[[4.633, 6.7  , 5.57 ]]],
       [[[3.709, 3.201, 5.758]]]], dtype=np.float16)
input_dict_0 = {'v6_0':input_data_0, 'v2_0':input_data_1}
inputs_0 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_0.items())
torch.onnx.export(model_0, inputs_0, '0.onnx', verbose=False, input_names=['v6_0', 'v2_0'], output_names=['v1_0', 'v3_0'], opset_version=14, do_constant_folding=False)

class Model1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0]
        getitem_1 = _args[1];  _args = None
        ceil = torch.ceil(getitem)
        abs_1 = torch.abs(ceil)
        max_1 = torch.max(getitem_1, abs_1)
        to = ceil.to(dtype = torch.int32)
        return (max_1, to)

model_1 = Model1()
input_dict_1 = {'v0_0':input_data_0, 'v3_0':input_data_1}
inputs_1 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_1.items())
torch.onnx.export(model_1, inputs_1, '1.onnx', verbose=False, input_names=['v0_0', 'v3_0'], output_names=['v5_0', 'v7_0'], opset_version=14, do_constant_folding=False)

output_names_0 = ['v1_0', 'v3_0']
sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_0 = ort.InferenceSession('0.onnx',providers=['CPUExecutionProvider'],sess_options=sess_options_0)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

output_names_1 = ['v5_0', 'v7_0']
sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_1 = ort.InferenceSession('1.onnx',providers=['CPUExecutionProvider'],sess_options=sess_options_1)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))
output_name_dict = {'v3_0': 'v5_0', 'v1_0': 'v7_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("onnxruntime_enable_opt does not trigger assertion")
except AssertionError as e:
    print("onnxruntime_enable_opt triggers assertion")
    print(e)
print('=========================')

output_names_0 = ['v1_0', 'v3_0']
sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_0 = ort.InferenceSession('0.onnx',providers=['CPUExecutionProvider'],sess_options=sess_options_0)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

output_names_1 = ['v5_0', 'v7_0']
sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_1 = ort.InferenceSession('1.onnx',providers=['CPUExecutionProvider'],sess_options=sess_options_1)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("onnxruntime_disable_opt does not trigger assertion")
except AssertionError as e:
    print("onnxruntime_disable_opt triggers assertion")
    print(e)
print('=========================')

onnx_model_0 = onnx.load('0.onnx')
onnx_model_outputs_0 = [node.name for node in onnx_model_0.graph.output]
shape_dict_0 = {key: val.shape for key, val in input_dict_0.items()}
mod_0, params_0 = relay.frontend.from_onnx(onnx_model_0, shape_dict_0, freeze_params=True)
with tvm.transform.PassContext(opt_level=0):
    executor_0 = relay.build_module.create_executor("graph", mod_0, tvm.cpu(), tvm.target.Target("llvm"), params_0).evaluate()
    executor_res_0 = [tensor.numpy() for tensor in executor_0(**input_dict_0)]
    output_0 = dict(zip(onnx_model_outputs_0, executor_res_0))

onnx_model_1 = onnx.load('1.onnx')
onnx_model_outputs_1 = [node.name for node in onnx_model_1.graph.output]
shape_dict_1 = {key: val.shape for key, val in input_dict_1.items()}
mod_1, params_1 = relay.frontend.from_onnx(onnx_model_1, shape_dict_1, freeze_params=True)
with tvm.transform.PassContext(opt_level=0):
    executor_1 = relay.build_module.create_executor("graph", mod_1, tvm.cpu(), tvm.target.Target("llvm"), params_1).evaluate()
    executor_res_1 = [tensor.numpy() for tensor in executor_1(**input_dict_1)]
    output_1 = dict(zip(onnx_model_outputs_1, executor_res_1))
output_name_dict = {'v3_0': 'v5_0', 'v1_0': 'v7_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("tvm_opt_4 does not trigger assertion")
except AssertionError as e:
    print("tvm_opt_4 triggers assertion")
    print(e)
print('=========================')
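For this second case, the expected values can also be worked out with NumPy alone. The following is a minimal sketch under stated assumptions: `numpy_reference_case_2` is a hypothetical name, and only the first two rows of `input_data_1` are used. Since ceil(5.766) is 6, the int32 output should be 6 and the max output should never fall below 6.

```python
import numpy as np

def numpy_reference_case_2(getitem, getitem_1):
    """NumPy version of the forward pass in the second Model0/Model1 pair."""
    ceil = np.ceil(getitem)               # float16 in, float16 out
    to = ceil.astype(np.int32)            # mirrors ceil.to(dtype=torch.int32)
    abs_1 = np.abs(ceil)
    max_1 = np.maximum(getitem_1, abs_1)  # broadcasts abs_1 over getitem_1
    return to, max_1

# Same scalar as input_data_0; tensor is a subset of input_data_1.
scalar = np.array([5.766], dtype=np.float16)
tensor = np.array([[[[6.566, 4.465, 6.43]]],
                   [[[4.63, 6.03, 6.438]]]], dtype=np.float16)

to, max_1 = numpy_reference_case_2(scalar, tensor)
print(to, max_1.shape, float(max_1.min()))
```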

@hariharans29 hariharans29 added core runtime issues related to core runtime and removed ep:tvm issues related to TVM execution provider labels Oct 26, 2023
@Azyka Azyka changed the title Models with multiple outputs produce incorrect results when handling fp16 data Models with multiple outputs produce different results when the order of irrelevant lines are changed Nov 1, 2023
@github-actions github-actions bot added the ep:tvm issues related to TVM execution provider label Nov 1, 2023

github-actions bot commented Jan 4, 2024

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions github-actions bot added the stale issues that have not been addressed in a while; categorized by a bot label Jan 4, 2024