[ShapeInferenceError] Inferred shape and existing shape differ in dimension 0 #1467

Closed · wants to merge 5 commits
202 changes: 202 additions & 0 deletions tests/dort_test.py
@@ -0,0 +1,202 @@
import copy
import inspect
import itertools
import sys
import unittest
import warnings
from typing import Optional

import onnxruntime
import torch
import torch.nn as nn
import torch.onnx
from onnx.inliner import inline_local_functions
Review comment (Collaborator):

Suggested change:
-from onnx.inliner import inline_local_functions
+import onnx.inliner
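If the suggestion is taken, the call site inside inline_function would be qualified accordingly (a sketch, not part of the diff):

    next_model = onnx.inliner.inline_local_functions(first_model_proto)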



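# Builds a DORT (torch.onnx._OrtBackend) backend for torch.compile: it checks
# that this torch build optimizes the exported model, configures the ORT
# session options, and registers a pre-ORT transform that inlines local
# functions in the exported ONNX model.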
def make_aot_ort(
dynamic: bool = False,
ort_optimization_level: Optional[str] = None,
) -> tuple:
Review comment (Collaborator), on the return annotation "-> tuple":

Is the return type accurate?
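As written, the function returns a single torch.onnx._OrtBackend rather than a tuple, so a more accurate annotation might be (an editorial sketch, not a reviewer suggestion):

    ) -> "torch.onnx._OrtBackend":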

code = inspect.getsource(
torch.onnx._internal.onnxruntime # pylint: disable=protected-access
)
if "optimizer.optimize" not in code:
raise unittest.SkipTest(
f"torch=={torch.__version__!r} is not recent enough, "
f"file {torch.onnx._internal.onnxruntime.__file__!r} " # pylint: disable=protected-access
f"does not optimize the exported model."
)

ort_session_options = onnxruntime.SessionOptions()
if ort_optimization_level is not None:
assert hasattr(onnxruntime.GraphOptimizationLevel, ort_optimization_level), (
f"Unexpected value {ort_optimization_level!r} for GraphOptimizationLevel, "
f"expecting one of the values in {dir(onnxruntime.GraphOptimizationLevel)}"
)
ort_session_options.graph_optimization_level = getattr(
onnxruntime.GraphOptimizationLevel, ort_optimization_level
)

export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic)

def inline_function(*args, **kwargs): # pylint: disable=unused-argument
first_model_proto = args[0]

next_model = inline_local_functions(first_model_proto)

del first_model_proto.graph.node[:]
del first_model_proto.functions[:]
del first_model_proto.graph.initializer[:]
del first_model_proto.opset_import[:]
first_model_proto.graph.node.extend(next_model.graph.node)
first_model_proto.functions.extend(next_model.functions)
first_model_proto.graph.initializer.extend(next_model.graph.initializer)
first_model_proto.opset_import.extend(next_model.opset_import)

return first_model_proto
Review comment (Collaborator) on lines +42 to +56:

Avoid nested function declarations when they don't close over local variables?

options = torch.onnx._OrtBackendOptions( # pylint: disable=protected-access
export_options=export_options,
ort_session_options=ort_session_options,
pre_ort_model_transforms=[inline_function],
)

return torch.onnx._OrtBackend(options=options) # pylint: disable=protected-access


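# Wraps a plain callable in a Module with a trainable parameter (ppp) that is
# added to the first input, so the backward pass has gradients to compare.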
class FuncModule(torch.nn.Module):
def __init__(self, f, params=None):
if params is None:
params = ()
super().__init__()
self.f = f
self.ppp = torch.nn.Parameter(torch.Tensor([1]))
self.params = torch.nn.ParameterList(list(params))

def forward(self, *args):
f_args = list(itertools.chain(args, self.params))
f_args[0] = f_args[0] + self.ppp
res = self.f(*f_args)
return res


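# Same idea for wrapping an existing nn.Module: ppp is added to the first
# input before the wrapped module is called.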
class FuncModuleModule(torch.nn.Module):
def __init__(self, f):
super().__init__()
self.f = f
self.mod = f
self.ppp = torch.nn.Parameter(torch.Tensor([1]))

def forward(self, *args):
x = args[0] + self.ppp
res = self.mod(x, *args[1:])
return res


class TestSimpleDort(unittest.TestCase):
def setUp(self):
super().setUp()
torch._dynamo.reset() # pylint: disable=protected-access
# Hide known warnings such as:
# "torch.onnx.dynamo_export only implements opset version 18 for now."
# "torch.library.impl_abstract was renamed to torch.library.register_fake."
warnings.simplefilter("ignore", (UserWarning, DeprecationWarning))

def assertONNX(
Review comment (Collaborator):

I would consider a more descriptive name. Does it assert that the output of the ONNX model matches the torch output, for example? Also, snake case for all function names. (A possible rename is sketched after the signature below.)

self,
f,
args,
onnx_export: str,
params=None,
fullgraph: bool = True,
atol=1e-6,
rtol=1e-6,
dynamic_axes=None,
):
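Per the naming comment above, a rename along these lines might address both points (hypothetical, not part of the PR):

    def assert_onnx_output_matches_torch(self, f, args, onnx_export: str, params=None, fullgraph: bool = True, atol=1e-6, rtol=1e-6, dynamic_axes=None):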
if sys.platform == "win32":
raise unittest.SkipTest("Windows not supported yet.")
assert isinstance(onnx_export, str), f"onnx_export must be a string for f={f}"
if isinstance(args, torch.Tensor):
args = [args]
if params is None:
params = ()
if isinstance(f, torch.nn.Module):
model = FuncModuleModule(f)
else:
model = FuncModule(f, params)
model.eval()

# forward/backward
local_aot_ort = make_aot_ort(dynamic=dynamic_axes is not None)

compiled_model = torch.compile(
copy.deepcopy(model),
backend=local_aot_ort,
dynamic=dynamic_axes is not None,
fullgraph=fullgraph,
)

baseline_result = model(*args)
result = compiled_model(*args)
Review comment (Contributor):

@xadupre I tried to put everything on CUDA and run the code, but even though compiled_model is on CUDA, this result will not be on CUDA. Do you know why?

if isinstance(baseline_result, tuple):
baseline_result = baseline_result[0]
result = result[0]
if isinstance(baseline_result, torch.Tensor):
torch.testing.assert_close(
baseline_result, result, atol=atol, rtol=rtol, equal_nan=True
)

baseline_result.sum().backward()
result.sum().backward()

l1 = list(model.parameters())
l2 = list(compiled_model.parameters())
self.assertEqual(len(l1), len(l2))
assert len(l1) > 0, "No gradient to test"
n_gradient = 0
for baseline_param, param in zip(l1, l2):
n_gradient += 1
torch.testing.assert_close(
baseline_param.grad,
param.grad,
atol=atol,
rtol=rtol,
equal_nan=True,
)
assert n_gradient > 0, "No gradient was checked"
else:
raise TypeError(f"Unexpected type {type(baseline_result)}.")

def test_acos(self):
# This test is just to make sure it is working.
x = torch.rand(3, 4, requires_grad=True)
self.assertONNX(
lambda x: x.acos(), x, onnx_export=inspect.currentframe().f_code.co_name
)

def test_batchnorm(self):
x = torch.ones(2, 2, 2, 2, requires_grad=True)
self.assertONNX(
nn.BatchNorm2d(2),
x,
onnx_export=inspect.currentframe().f_code.co_name,
)

def test_batchnorm_onnx_irv4(self):
x = torch.ones(2, 2, 2, 2, requires_grad=True)
self.assertONNX(
nn.BatchNorm2d(2), x, onnx_export=inspect.currentframe().f_code.co_name
)

def test_batchnorm_1d(self):
x = torch.ones(2, 2, requires_grad=True)
self.assertONNX(
nn.BatchNorm1d(2),
x,
onnx_export=inspect.currentframe().f_code.co_name,
)


if __name__ == "__main__":
unittest.main(verbosity=2)