SkipLayerNormFusion -- High Output Difference Between PyTorch and ONNX Runtime with Extended Optimizations #17689
Comments
This might take some time to track down. One quick way to continue investigating is to try disabling optimizers in the EXTENDED suite (see comment here - #17476 (comment)) and narrow it down to the optimizer that is causing the problem.
@hariharans29 Indeed, for this kind of faulty-optimizer debugging it would be useful to get a list of enabled optimizer names for a given optimization level. Can this be exposed through the API?
@vadimkantorov / @WoodieDudy -
Same ask here - #17476. No API is available currently to retrieve the list of optimizers. We will need to prioritize the request.
See example usage here
Extended optimizers, or Level 2 optimizers (which is where you are seeing the bug manifest), should be listed here. Based on the usage example above, to disable the fusion of Gemm + Activation, use "GemmActivationFusion". There seems to be a check to ensure that the optimizers being disabled are available in the build. Keep in mind that some Level 2 optimizers may not be applicable to the CUDA EP; for example, this one is only compatible with the CPU EP. I understand this is not the most ideal way to debug, but it is the best way I can think of to get started with the debugging.
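For concreteness, here is a minimal sketch (the model path is a placeholder) of disabling one named Level 2 optimizer via the disabled_optimizers argument of InferenceSession, the same mechanism used in the later comment below:

```python
import onnxruntime

sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

# "model.onnx" is a placeholder path; "GemmActivationFusion" is the CPU-only Level 2 fusion named above.
session = onnxruntime.InferenceSession(
    "model.onnx",
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
    disabled_optimizers=["GemmActivationFusion"],
)
```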
I feel there should be a flag such as ...
@hariharans29 It turned out that disabling SkipLayerNormFusion removes the output difference:

```python
# Assumes `import onnxruntime` at module level and `providers` defined earlier
# (e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"]).
sess_options = onnxruntime.SessionOptions()
# sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

self.onnxruntime_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=providers,
    sess_options=sess_options,
    disabled_optimizers=[
        # 'QDQS8ToU8Transformer',
        # 'QDQSelectorActionTransformer',
        # 'ConvActivationFusion',
        # 'GeluFusion',
        # 'LayerNormFusion',
        # 'SimplifiedLayerNormFusion',
        # 'AttentionFusion',
        # 'EmbedLayerNormFusion',
        # 'GatherToSplitFusion',
        # 'GatherToSliceFusion',
        # 'MatmulTransposeFusion',
        # 'BiasGeluFusion',
        # 'FastGeluFusion',
        # 'GeluApproximation'
        'SkipLayerNormFusion',
        # 'QuickGeluFusion',
    ],
)
```
Thanks for the investigation thus far. I guess we have some issue in the SLN fusion logic and/or the SLN kernel implementation. CC: @tianleiwu.
@hariharans29, anyone you'd recommend to fix this? Please feel free to assign this issue. This looks like an important operator.
@wschin - I am unable to recommend anyone to fix this. Maybe @tianleiwu can comment. Also, from the OP's comment: "This bug is only reproducible on a trained model, but I can't post it. But here is an example of a randomly initialized model. Maybe it can help to understand the precision casts inside." Not sure if the posted model repros the SLN issue, as the original issue was posted with the following comment (...).
I looked at bugmodel.onnx; the SkipLayerNormalization only involves 4 inputs (input, skip, weight and bias), and the input shape is 16x4x512. No data Cast happens. The pattern is normal, so it is not an issue of graph fusion. I could run some tests with random inputs to see whether the SkipLayerNormalization kernel has a precision issue. Update: evaluating a BERT model on SQuAD data shows no regression in end-to-end accuracy. SkipLayerNormalization uses fp16 in some part of the accumulation; in theory, it might drop precision in some scenarios. I suggest evaluating end-to-end accuracy, then deciding whether to disable the fusion or fall back to fp32 accumulation using the enable_skip_layer_norm_strict_mode option (see the next comment).
@WoodieDudy, could you try adding the flag {"enable_skip_layer_norm_strict_mode": True} to the provider options? Example usage: onnxruntime/onnxruntime/python/tools/transformers/convert_generation.py Lines 671 to 677 in a441a71
Let me know whether that reduces the output difference.
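For reference, a minimal sketch of passing that flag as a CUDA execution provider option (the model path and session options are placeholders):

```python
import onnxruntime

sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

# Force fp32 accumulation in the fused SkipLayerNormalization kernel on the CUDA EP.
providers = [
    ("CUDAExecutionProvider", {"enable_skip_layer_norm_strict_mode": True}),
    "CPUExecutionProvider",
]

session = onnxruntime.InferenceSession("model.onnx", sess_options=sess_options, providers=providers)
```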
@tianleiwu @hariharans29 The model in the repro is in fp16. My assumption is that the recalculation of weights for the fusion is done in fp16, and because of this you get a bad result. It would be better to cast to fp32 first, then do the fusion, then cast back to fp16. By the way, {"enable_skip_layer_norm_strict_mode": True} solves the problem too.
It is likely caused by the SkipLayerNorm kernel using fp16 in some part of the accumulation. For BERT-like models the accuracy is acceptable (for example, the SQuAD use case). {"enable_skip_layer_norm_strict_mode": True} will force fp32 accumulation, but that gives worse performance. It might be needed for some decoder models. There is a trade-off between performance and accuracy here; evaluate your end-to-end accuracy to pick one.
@tianleiwu is there any weight recomputation happening during fusion? Is it by chance done in fp16? If your hypothesis is true, why is fusion changing things so much? Why does enabling fusion change the accumulation dtype? Shouldn't accumulation in LayerNorm always happen in fp32?
@vadimkantorov, there is NO weight recomputation during fusion; see the source code. The accumulation dtype is hard-coded in the CUDA kernel and is not related to fusion. For example, the layer normalization kernel uses T everywhere (which means it uses T=fp16 for some part of the accumulation).
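To make the precision point concrete, here is a small illustrative NumPy sketch (not the actual CUDA kernel) comparing a mean reduction over fp16 data accumulated in fp16 versus in fp32; this is the kind of difference the strict mode flag controls. The data is synthetic, so the error magnitude is only illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=1.0, scale=1.0, size=4096).astype(np.float16)  # fp16 activations

# Running sum kept in fp16 (analogous to a kernel accumulating in T=fp16).
sum_fp16 = np.float16(0.0)
for v in x:
    sum_fp16 = np.float16(sum_fp16 + v)
mean_fp16 = sum_fp16 / np.float16(x.size)

# Same reduction with fp32 accumulation (what the fp32 fallback / strict mode effectively does).
mean_fp32 = x.astype(np.float32).mean()

print("fp16-accumulated mean:", float(mean_fp16))
print("fp32-accumulated mean:", float(mean_fp32))
print("abs diff:", abs(float(mean_fp16) - float(mean_fp32)))
```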
Does that mean, then, that the accumulation dtype / computation precision is actually different between regular LayerNorm and the case where SkipLayerNormFusion is enabled? (Model weights and inputs are fp16 in both cases.)
I have a trained PyTorch model. When I export it to ONNX with static shapes, in fp16, and with all optimizations enabled by default (ort.GraphOptimizationLevel.ORT_ENABLE_ALL), and then compare the network outputs between PyTorch and ONNX Runtime, I observe a significant difference in the output tensors:

I get a small difference with any of the following:
- optimization level set to basic
- dynamic_shapes and GraphOptimizationLevel.ORT_ENABLE_ALL
- fp32

For the cases above, the output diff looks like this:
I think this is a bug in ORT's extended optimization fusions or in the fp16<->fp32 casts.
This bug is only reproducible on a trained model, which I can't post. But here is an example of a randomly initialized model; maybe it can help to understand the precision casts inside.
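As a rough illustration of the comparison setup (not the actual repro model, which cannot be shared), a self-contained sketch could look like the following; the stand-in block uses a residual Add followed by LayerNorm, the pattern SkipLayerNormFusion targets, and the 16x4x512 input shape mentioned above:

```python
import numpy as np
import onnxruntime as ort
import torch

# Minimal stand-in: residual Add + LayerNorm, the pattern SkipLayerNormFusion targets.
class Block(torch.nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)
        self.norm = torch.nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x + self.linear(x))

model = Block().half().cuda().eval()
dummy_input = torch.randn(16, 4, 512, dtype=torch.float16, device="cuda")

torch.onnx.export(model, dummy_input, "model_fp16.onnx", opset_version=17)

with torch.no_grad():
    torch_out = model(dummy_input).float().cpu().numpy()

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession(
    "model_fp16.onnx",
    sess_options=sess_options,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
ort_out = session.run(None, {session.get_inputs()[0].name: dummy_input.cpu().numpy()})[0]

print("max abs diff:", np.abs(torch_out - ort_out.astype(np.float32)).max())
```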
Thanks
Urgency
Not urgent
Platform
Linux
OS Version
Ubuntu 22.04
ONNX Runtime Version
onnxruntime-gpu==1.16.0
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
CUDA Version: 12.1