Optimizer fails on shape inference error over native_batch_norm #1443
Comments
@titaiwangms Do you have an idea? Looks like it is related to https://github.com/microsoft/onnxscript/blame/0d98619dee85025f8fb110864607f6f477c3d8ae/onnxscript/function_libs/torch_lib/ops/core.py#L5625 |
I can take a look after I finish the work currently on my hands. |
@xadupre I think the optimizer should already be applied in torch-nightly? https://github.com/pytorch/pytorch/blob/d5182bb75bbc109cb327212e7205981fbf72cb5e/torch/onnx/_internal/exporter.py#L1274 Are you writing a new optimization? Just trying to understand the usage here. Specifically, if we try:

```python
import onnx
from onnxscript import optimizer
from onnxscript.rewriter import onnxruntime as ort_rewriter

onx = onnx.load("dump3bug.onnx")
onnx.checker.check_model(onx, full_check=True)
optimized = optimizer.optimize(onx)
```

the same error is spotted by the checker. |
It does. I tweaked the torch code in onnxruntime.py to get the model before it gets optimized, to find out whether the error happens before optimization or after. It is after. |
Could you update the model to the one before optimizer? |
I'll check again but it should be the one before optimizer. |
I mean the model in the zip doesn't pass onnx.checker.check_model(model, full_check=True). That's why it gets the error message from
It's not even hitting the constant folding and general rewriter yet it seems. |
I wonder if we should put onnx.checker to guard the models generated from converter/dort. Or we already did? |
True ... tried it, and this seems to fail |
I would not call onnx.checker. The converter may introduce nodes coming from domain com.microsoft. I created PR #1467 to replicate the issue. |
So I think there are two issues here. The first is that if we don't want to require our models to pass the checker before feeding them to the optimizer, we should turn off strict_mode in the ONNX shape type inference inside the optimizer, since the two perform basically the same check. I will submit a PR for this to unblock this model. The other issue is that, in torchlib, we follow PyTorch's CUDA native_batch_norm, which produces size=0 outputs at indices 1 and 2 (here); this originates from the PyTorch code. That's why the error message says the existing shape is 0. However, ONNX shape type inference infers this as 2. @justinchuby @xiaowuhu @gramalingam any suggestion on this? |
Do we know if this model was exported with cuda or with cpu? Even though the models exported under cuda are different from those under cpu, each of them should pass shape inference, or is there something I don't remember? |
Fix #1443 In converter/dort, tensors retain their shape and type from the PyTorch models, which saves us the effort of inferring them all as we did in torchscript. However, when it comes to symbolic shapes, we still need ONNX shape type inference. An error is raised when the inferred shape and type differ from the carried ones. This is rare, but it happens when a corner case is revealed. For example, in #1443, PyTorch generates 2 outputs with size=0 when native_batch_norm is run with CUDA. This PR turns off strict mode in ONNX shape type inference to avoid the crash in the optimizer.
@xadupre I executed the tests with CUDA and reproduced the error. Could you point to the code that is "passing shape inference"? My guess is that it does not invoke strict mode. |
I think it should be covered in the torchlib tests, but we don't run it with cuda regularly. |
Hi, is this related to #1256 ? |
Given the comment in the code that Titai links above, it appears that cuda and cpu have different behavior? But the onnxscript encoding chooses one of the two behaviors (it says cuda) ... now, if the actual shapes being emitted are those produced by the runtime, there is going to be a mismatch between the shape inferred by ONNX (the cuda shape) and the embedded valueinfo shape (coming from cpu) ... that would explain it, right? |
But Titai also says the error is reproduced in a cuda run, which seems strange (inconsistent with the message here) |
Writing down some findings today: this is only reproducible with DORT. dynamo_export does not support this case, because it is decomposed at aot_autograd (functionalization). And ExportedProgram can't repro this because the unused outputs are trimmed. |
The optimizer fails for the attached model (so dort fails as well). It was obtained with the latest onnx, onnxscript and torch nightly.
dump3bug.zip
To replicate:
It is coming from the following graph module.
Error: