Describe the issue
Hello, I am experiencing an issue where ONNX Runtime produces NaN outputs for the atan2 operation when both inputs, x and y, are zero. This is inconsistent with PyTorch, which handles this division-by-zero case and returns 0. Thanks!
To reproduce
Create a simple PyTorch model using the torch.atan2 function.
Export this model to ONNX format.
Run inference with both PyTorch and ONNX Runtime on inputs where x and y are both tensors of zeros.
Observe that PyTorch handles the operation and returns zeros, while ONNX Runtime returns NaN.
import torch
import torch.nn as nn
import onnxruntime as ort


class ATan2Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # torch.atan2(input, other) computes arctan(input / other), quadrant-aware
        return torch.atan2(y, x)


if __name__ == '__main__':
    model = ATan2Model()
    x = torch.zeros(4, dtype=torch.float)
    y = torch.zeros(4, dtype=torch.float)
    model_output = model(x, y)

    input_names = ["x", "y"]
    output_names = ["output"]
    onnx_filename = "atan2_model.onnx"
    torch.onnx.export(model, (x, y), onnx_filename, verbose=False,
                      input_names=input_names, output_names=output_names,
                      export_params=True, opset_version=16,
                      keep_initializers_as_inputs=True, do_constant_folding=True)

    onnx_model = ort.InferenceSession(onnx_filename, providers=['CPUExecutionProvider'])
    inputs = {"x": x.numpy(), "y": y.numpy()}
    onnx_output = onnx_model.run(None, inputs)
    print("PyTorch Model Output:", model_output.numpy())
    # run() returns a list with one array per output; print the single output
    print("ONNX Model Output:", onnx_output[0])
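The arithmetic behind the discrepancy can be isolated without either framework. A minimal sketch, using NumPy as a stand-in for both PyTorch and ONNX Runtime: a true two-argument arctangent is defined at the origin, while computing arctan(y / x) directly yields NaN because 0/0 is NaN under IEEE 754. This only illustrates the math; it does not reproduce the exported ONNX graph.

```python
import numpy as np

x = np.zeros(4, dtype=np.float32)
y = np.zeros(4, dtype=np.float32)

# Two-argument arctangent: defined at the origin, returns +0 for (+0, +0).
true_atan2 = np.arctan2(y, x)

# Naive one-argument form: 0/0 evaluates to NaN, and arctan(NaN) is NaN.
with np.errstate(invalid="ignore"):  # silence the 0/0 RuntimeWarning
    naive = np.arctan(y / x)

print("arctan2(y, x):", true_atan2)
print("arctan(y / x):", naive)
```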
Urgency
No response
Platform
Linux
OS Version
Ubuntu 20.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
onnxruntime==1.16.3
ONNX Runtime API
Python
Architecture
X86
Execution Provider
Default CPU
Execution Provider Library Version
CUDA==12.0
I think NaN is the correct value by the mathematical definition. According to PyTorch's documentation, atan2(x, y) computes arctan(x/y). Since x = 0 and y = 0 make x/y undefined, the output of arctan is undefined as well.
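To make the "arctan of a quotient" view concrete, the sketch below is a hypothetical NumPy emulation of an elementwise lowering of atan2: divide, arctan, then a `where`-based quadrant correction. It assumes, without confirming, that the exported graph computes atan2 in roughly this way; inspecting the exported model is the way to check the actual ops. It shows that such a lowering inherits NaN from 0/0, and that guarding x == 0 before dividing recovers the atan2 values on the y-axis. The function names `quotient_atan2` and `guarded_atan2` are illustrative only.

```python
import math

import numpy as np


def quotient_atan2(y, x):
    """Hypothetical lowering: arctan(y / x) plus a quadrant fix when x < 0."""
    with np.errstate(divide="ignore", invalid="ignore"):
        base = np.arctan(y / x)  # 0/0 -> NaN, and arctan(NaN) -> NaN
    # For x < 0 the ratio drops the quadrant, so shift by +pi or -pi.
    shifted = np.where(y >= 0, base + math.pi, base - math.pi)
    return np.where(x < 0, shifted, base)


def guarded_atan2(y, x):
    """Same lowering, but the x == 0 positions are handled without dividing."""
    safe_x = np.where(x == 0, 1.0, x)  # placeholder divisor, result discarded below
    base = np.arctan(y / safe_x)
    shifted = np.where(y >= 0, base + math.pi, base - math.pi)
    regular = np.where(x < 0, shifted, base)
    # On the y-axis atan2 is +pi/2 or -pi/2; at the origin this sketch returns 0.
    on_axis = np.where(y > 0, math.pi / 2, np.where(y < 0, -math.pi / 2, 0.0))
    return np.where(x == 0, on_axis, regular)


zeros = np.zeros(4)
print("quotient_atan2:", quotient_atan2(zeros, zeros))
print("guarded_atan2: ", guarded_atan2(zeros, zeros))
```

The guarded variant does not distinguish signed zeros (-0.0 is treated like +0.0), so it matches atan2 at the origin only for the +0.0 inputs used in the reproduction script.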
@wschin
Thanks for your reply. std::atan2 in C++ is defined to handle more boundary conditions than the one-argument arctangent (std::atan). It computes the arctangent of y/x but takes y and x as separate arguments, so it can use the sign of each to place the result in the correct quadrant. Unlike a direct division, atan2 is well defined when y = 0 and x = 0, as well as in other cases where y/x would be undefined or infinite.
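A small sketch of the quadrant and boundary handling described above, using Python's `math.atan2`, which its documentation describes as returning atan(y / x) while using the signs of both inputs to pick the quadrant. Points in opposite quadrants share the same ratio y/x, so the one-argument form cannot tell them apart; the two-argument form can, and it stays defined when x is zero.

```python
import math

# (y, x) pairs: the first two share the ratio y/x == 1, the last two share -1.
points = [(1.0, 1.0), (-1.0, -1.0), (1.0, -1.0), (-1.0, 1.0)]

rows = []
for y, x in points:
    one_arg = math.atan(y / x)   # sees only the ratio, range [-pi/2, pi/2]
    two_arg = math.atan2(y, x)   # sees both signs, range [-pi, pi]
    rows.append((y, x, one_arg, two_arg))
    print(f"y={y:+}, x={x:+}: atan(y/x)={one_arg:+.4f}, atan2(y, x)={two_arg:+.4f}")

# On the y-axis the quotient form fails in plain Python; atan2 does not.
axis_angle = math.atan2(1.0, 0.0)  # pi / 2
try:
    math.atan(1.0 / 0.0)
except ZeroDivisionError as exc:
    print("atan(1/0) fails:", exc)
print("atan2(1, 0) =", axis_angle)
```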
PyTorch's implementation, as seen in the PyTorch source code on GitHub, uses std::atan2. This choice lets it handle the sign of each quadrant and the boundary conditions more robustly than a plain arctan.
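When both x and y are zero, arctan(y/x) is mathematically undefined because it involves division by zero. std::atan2(0, 0), however, may raise a domain error in general, and on implementations that support IEEE floating-point arithmetic (IEC 60559) it returns a specific value instead: ±0 when x is +0 (or positive) and ±π when x is -0 (or negative), with the sign taken from y. For the +0 inputs in the reproduction above, that value is 0.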
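The signed-zero cases at the origin can be checked directly. This sketch uses Python's `math.atan2`, which handles signed zeros the same way as the IEC 60559 rules for C's atan2: the sign of x chooses between 0 and π, and the sign of y is copied to the result.

```python
import math

# All four signed-zero combinations of (y, x) at the origin.
cases = [(0.0, 0.0), (-0.0, 0.0), (0.0, -0.0), (-0.0, -0.0)]
results = [math.atan2(y, x) for y, x in cases]

for (y, x), angle in zip(cases, results):
    # repr() keeps the sign of zero visible, e.g. -0.0
    print(f"atan2({y!r}, {x!r}) = {angle!r}")
```

A list is used rather than a dict keyed by (y, x) because 0.0 == -0.0 in Python, so such keys would collide.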