Optimizer fails on shape inference error over native_batch_norm #1443

Open
xadupre opened this issue Apr 25, 2024 · 19 comments · Fixed by #1472
Labels
bug Something isn't working topic: rewriter

Comments

@xadupre
Member

xadupre commented Apr 25, 2024

The optimizer fails for the attached model (so dort fails as well). The model was obtained with the latest onnx, onnxscript and torch nightly.

dump3bug.zip

To replicate:

import onnx
from onnxscript import optimizer

# Path to the model extracted from dump3bug.zip
model = "dump3bug.onnx"
onx = onnx.load(model)
optimized = optimizer.optimize(onx)

It is coming from the following graph module.

graph():
    %primals_7 : [num_users=1] = placeholder[target=primals_7]
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=1] = placeholder[target=primals_2]
    %primals_3 : [num_users=1] = placeholder[target=primals_3]
    %primals_4 : [num_users=1] = placeholder[target=primals_4]
    %primals_5 : [num_users=1] = placeholder[target=primals_5]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_7, %primals_1), kwargs = {})
    %_native_batch_norm_legit_no_training : [num_users=1] = call_function[target=torch.ops.aten._native_batch_norm_legit_no_training.default](args = (%add, %primals_2, %primals_3, %primals_4, %primals_5, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_no_training, 0), kwargs = {})
    return (add, getitem)

Error:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "check_model.py", line 43, in <module>
    optimized = optimizer.optimize(onx)
  File "onnxscript/onnxscript/optimizer/__init__.py", line 61, in optimize
    model = onnx.shape_inference.infer_shapes(
  File "onnx/onnx/shape_inference.py", line 46, in infer_shapes
    inferred_model_str = C.infer_shapes(
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:_aten_native_batch_norm_inference_onnx, node name: _aten_native_batch_norm_inference_onnx_2): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 0: (2) vs (0)
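
The last line of the traceback already names the failing node and the two conflicting dimensions. As a rough sketch, the snippet below pulls those fields out of the message with a regular expression. The pattern is written against the exact message shown above and is an assumption about its format; other onnx versions may word it differently, and `parse_shape_mismatch` is a hypothetical helper, not part of onnx or onnxscript.

```python
import re

# Pattern derived from the single error message quoted in this issue
# (assumption: the wording is stable enough to match).
PATTERN = re.compile(
    r"\(op_type:(?P<op_type>[^,]+), node name: (?P<node>[^)]+)\): "
    r"\[ShapeInferenceError\] Inferred shape and existing shape differ "
    r"in dimension (?P<dim>\d+): \((?P<inferred>\d+)\) vs \((?P<existing>\d+)\)"
)


def parse_shape_mismatch(message):
    """Return the node and dimension details of a shape mismatch, or None."""
    match = PATTERN.search(message)
    if match is None:
        return None
    return {
        "op_type": match.group("op_type"),
        "node": match.group("node"),
        "dimension": int(match.group("dim")),
        "inferred": int(match.group("inferred")),
        "existing": int(match.group("existing")),
    }


message = (
    "[ShapeInferenceError] Inference error(s): "
    "(op_type:_aten_native_batch_norm_inference_onnx, "
    "node name: _aten_native_batch_norm_inference_onnx_2): "
    "[ShapeInferenceError] Inferred shape and existing shape differ "
    "in dimension 0: (2) vs (0)"
)
print(parse_shape_mismatch(message))
```

Here the inferred size is 2 and the size already recorded in the model is 0, both on dimension 0 of an output of the batch norm node.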
@justinchuby
Collaborator

@justinchuby justinchuby added the topic: torch_lib and bug labels Apr 25, 2024
@titaiwangms
Contributor

I can take a look after I finish the work I have on hand.

@justinchuby justinchuby added the topic: rewriter label and removed the topic: torch_lib label Apr 25, 2024
@titaiwangms
Contributor

titaiwangms commented Apr 25, 2024

@xadupre I think the optimizer should already be applied in torch-nightly? https://github.com/pytorch/pytorch/blob/d5182bb75bbc109cb327212e7205981fbf72cb5e/torch/onnx/_internal/exporter.py#L1274

Are you writing a new optimization? Just trying to understand the usage here.

Specifically, if we try:

import onnx
from onnxscript import optimizer
from onnxscript.rewriter import onnxruntime as ort_rewriter

onx = onnx.load("dump3bug.onnx")
onnx.checker.check_model(onx, full_check=True)
optimized = optimizer.optimize(onx)

The same error is caught by the checker.

@xadupre
Member Author

xadupre commented Apr 25, 2024

It does. I tweaked the torch code in onnxruntime.py to get the model before it gets optimized, to know whether the error happens before or after optimization. It is after.

@titaiwangms
Contributor

Could you update the model to the one from before the optimizer runs?

@xadupre
Member Author

xadupre commented Apr 25, 2024

I'll check again, but it should be the one from before the optimizer runs.

@titaiwangms
Contributor

I mean the model in zip doesn't pass onnx.checker.check_model(model, full_check=True). That's why it gets the error message from

model = onnx.shape_inference.infer_shapes(

It's not even hitting the constant folding and general rewriter yet it seems.

@titaiwangms
Contributor

titaiwangms commented Apr 25, 2024

I wonder if we should use onnx.checker to guard the models generated by converter/dort. Or do we already?

@gramalingam
Collaborator

I mean the model in zip doesn't pass onnx.checker.check_model(model, full_check=True).

True ... tried it, and this seems to fail

@xadupre
Member Author

xadupre commented Apr 26, 2024

I would not call onnx.checker. The converter may introduce nodes coming from domain com.microsoft. I created PR #1467 to replicate the issue.

@titaiwangms
Contributor

So I think there are two issues here. First, if we don't want to require that models pass the checker before they are fed to the optimizer, we should turn off strict_mode in ONNX shape type inference inside the optimizer, since the two checks are basically the same. I will submit a PR for this to unblock this model.

The other issue is that, in torchlib, we follow PyTorch's native_batch_norm CUDA behavior and accept size=0 outputs at index=1 and 2 (here), which originates from PyTorch code. That's why the error message says the existing shape is 0. However, ONNX shape type inference infers this as 2. @justinchuby @xiaowuhu @gramalingam any suggestions on this?
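
To make the role of strict mode concrete, here is a toy model of the check in plain Python. It is only a sketch and not onnx's implementation: `reconcile_shape` and this `ShapeInferenceError` class are hypothetical stand-ins. The real switch is the `strict_mode` argument of `onnx.shape_inference.infer_shapes`. The shapes `(0,)` and `(2,)` are the ones from the error message in this issue.

```python
class ShapeInferenceError(Exception):
    """Hypothetical stand-in for the error raised by a strict shape check."""


def reconcile_shape(existing, inferred, strict_mode):
    """Toy model: compare a shape recorded in the model with an inferred one.

    With strict_mode=True any mismatch raises. With strict_mode=False the
    mismatch is tolerated and the recorded shape is kept unchanged.
    """
    mismatch = None
    if len(existing) != len(inferred):
        mismatch = (
            f"Inferred rank and existing rank differ: "
            f"({len(inferred)}) vs ({len(existing)})"
        )
    else:
        for dim, (inf, ex) in enumerate(zip(inferred, existing)):
            if inf != ex:
                mismatch = (
                    f"Inferred shape and existing shape differ "
                    f"in dimension {dim}: ({inf}) vs ({ex})"
                )
                break
    if mismatch is None:
        return tuple(inferred)
    if strict_mode:
        raise ShapeInferenceError(mismatch)
    return tuple(existing)


recorded = (0,)  # size-0 shape carried in the exported model
inferred = (2,)  # shape computed by shape inference for the same output

# Non-strict: the mismatch is ignored and the recorded shape survives.
print(reconcile_shape(recorded, inferred, strict_mode=False))

# Strict: the same mismatch becomes a hard error.
try:
    reconcile_shape(recorded, inferred, strict_mode=True)
except ShapeInferenceError as exc:
    print(exc)
```

Turning strict mode off unblocks the optimizer, but it leaves the underlying disagreement between the recorded and inferred shapes in place, which is the second issue described above.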

@justinchuby
Collaborator

Do we know if this model is exported with cuda or with cpu? Even though models exported under cuda are different from those exported under cpu, each of them should pass shape inference, or there must be something I don't remember?

titaiwangms added a commit that referenced this issue Apr 26, 2024
Fix #1443 

In converter/dort, tensors retain their shape and type from PyTorch
models, which saves us the effort of inferring them all as we did in
torchscript. However, when it comes to symbolic shapes, we still need
ONNX shape type inference. An error is raised when the inferred shape and
type differ from the carried ones. This is rare, but it happens
when a corner case is revealed. For example, in #1443, PyTorch generates
2 outputs with size=0 when native_batch_norm is run with CUDA.

This PR turns off the strict mode in ONNX shape type inference to avoid
a crash in the optimizer.
@titaiwangms
Contributor

Do we know if this model is exported with cuda or with cpu? Even though models exported under cuda are different from those exported under cpu, each of them should pass shape inference, or there must be something I don't remember?

I ran @xadupre's tests with CUDA and reproduced the error. Could you point to the code that is "passing shape inference"? My guess is that it does not invoke strict mode.

@justinchuby
Collaborator

I think it should be covered in the torchlib tests, but we don't run it with cuda regularly.

@justinchuby justinchuby reopened this Apr 26, 2024
@justinchuby justinchuby changed the title from "optimizer fails on this particular model" to "Optimizer fails on shape inference error over native_batch_norm" Apr 26, 2024
@gramalingam
Collaborator

Hi, is this related to #1256 ?

@gramalingam
Collaborator

Do we know if this model is exported with cuda or with cpu? Even though models exported under cuda are different from those exported under cpu, each of them should pass shape inference, or there must be something I don't remember?

Given the comment in the code that Titai links above, it appears that cuda/cpu have different behavior? But the onnxscript encoding chooses one of the two behaviors (it says cuda) ... now, if the actual shapes are being emitted as produced by the runtime, there is going to be a mismatch between shape inferred by ONNX (the cuda shape) and the valueinfo shape embedded (coming from cpu) ... that would explain it, right?

@gramalingam
Collaborator

But Titai also says the error is reproduced in a cuda run, which seems strange (inconsistent with the message here)

@titaiwangms
Contributor

I guess we need to find out what happened in ONNX shape type inference. One can try this out with #1467 test cases, and turn #1472 strict mode back to True.

@titaiwangms
Contributor

Writing down some findings from today:

This is only reproducible on DORT. Dynamo_export does not support this case, because it is decomposed at aot_autograd (Functionalization). And ExportedProgram can't repro this because the unused outputs are trimmed.

justinchuby pushed a commit that referenced this issue May 1, 2024
@xadupre xadupre assigned xadupre and unassigned titaiwangms Jun 21, 2024