
Output mismatch of duplicate torch.Tensor.to nodes after optimization #18211

Open
Azyka opened this issue Nov 1, 2023 · 4 comments
Labels
converter (related to ONNX converters), stale (issues that have not been addressed in a while; categorized by a bot)

Comments


Azyka commented Nov 1, 2023

Describe the issue

ONNX opset version: 14
When two duplicate torch.Tensor.to nodes (nodes with the same inputs and outputs) are defined, the model produces wrong results after ORT optimization.

To reproduce

Model code:

import onnxruntime as ort
import onnx
import numpy as np
from numpy import testing
import torch

class Model0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0]
        tan = torch.tan(getitem)
        to = tan.to(dtype=torch.int64)
        to_e = tan.to(dtype=torch.int64)
        to_1 = to_e.to(dtype=torch.bool)
        return (to, to_1)

model_0 = Model0()
output_names_0 = ['v4_0', 'v3_0']
input_data_0 = np.array(3.645, dtype=np.float32)
input_dict_0 = {'v5_0':input_data_0}
inputs_0 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_0.items())
torch.onnx.export(model_0, inputs_0, '0.onnx', verbose=False, input_names=['v5_0'], output_names=output_names_0, opset_version=14, do_constant_folding=False)

class Model1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0]
        tan = torch.tan(getitem)
        to = tan.to(dtype=torch.int64)
        to_1 = to.to(dtype=torch.bool)
        to_e = tan.to(dtype=torch.int64)
        return (to_e, to_1)

model_1 = Model1()
output_names_1 = ['v5_0', 'v7_0']
input_dict_1 = {'v0_0':input_data_0}
inputs_1 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_1.items())
torch.onnx.export(model_1, inputs_1, '1.onnx', verbose=False, input_names=['v0_0'], output_names=output_names_1, opset_version=14, do_constant_folding=False)

sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_0 = ort.InferenceSession('0.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options_0)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_1 = ort.InferenceSession('1.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options_1)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))
output_name_dict = {'v4_0': 'v5_0', 'v3_0': 'v7_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        print(tensor_name_0, tensor_name_1)
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("onnxruntime_enable_opt does not trigger assertion")
except AssertionError as e:
    print("onnxruntime_enable_opt triggers assertion")
    print(e)
print('=========================')

sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_0 = ort.InferenceSession('0.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options_0)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_1 = ort.InferenceSession('1.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options_1)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("onnxruntime_disable_opt does not trigger assertion")
except AssertionError as e:
    print("onnxruntime_disable_opt triggers assertion")
    print(e)
print('=========================')

Output:

=========================
onnxruntime_enable_opt triggers assertion

Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1 / 1 (100%)
 x: array(True)
 y: array(False)
=========================
=========================
onnxruntime_disable_opt does not trigger assertion
=========================

After optimization, Model0 produces array(True) for this output, where array(False) is expected.
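The expected values are easy to verify by replaying the cast chain in plain NumPy. The direct float-to-bool cast at the end is a hedged guess at the failure mode: the observed array(True) is consistent with (though does not prove) the optimizer collapsing Cast(float→int64) followed by Cast(int64→bool) into a single Cast(float→bool):

```python
import numpy as np

x = np.array(3.645, dtype=np.float32)
t = np.tan(x)                      # roughly 0.55
as_int = t.astype(np.int64)        # truncation toward zero -> 0
as_bool = as_int.astype(np.bool_)  # 0 -> False: the expected output value
fused = t.astype(np.bool_)         # direct float->bool cast of a nonzero value -> True
print(t, as_int, as_bool, fused)
```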

Urgency

This is a correctness bug in the optimizer. It may cause severe problems for systems built on top of ORT.

Platform

Linux

OS Version

Ubuntu 22.04.3 LTS (x86_64)

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@Azyka Azyka changed the title Output Mismatch of duplicate torch.Tensor.to nodes after optimization Output mismatch of duplicate torch.Tensor.to nodes after optimization Nov 1, 2023
@carzh carzh added the core runtime label Nov 1, 2023

github-actions bot commented Dec 4, 2023

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions github-actions bot added the stale label Dec 4, 2023
@thiagocrepaldi thiagocrepaldi added the converter label and removed the core runtime and stale labels Dec 4, 2023
thiagocrepaldi (Contributor) commented

@Azyka this might be fixed after pytorch/pytorch#96320 if you set torch.onnx.export(..., keep_initializers_as_inputs=True)

Try it out and let us know how it goes


Azyka commented Dec 5, 2023


@thiagocrepaldi I tried keep_initializers_as_inputs=True with the latest torch version, and the error still exists.

=========================
onnxruntime_enable_opt triggers assertion

Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1 / 1 (100%)
 x: array(True)
 y: array(False)
=========================
=========================
onnxruntime_disable_opt does not trigger assertion
=========================


github-actions bot commented Jan 4, 2024

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions github-actions bot added the stale label Jan 4, 2024