
SkipLayerNormFusion -- High Output Difference Between PyTorch and ONNX Runtime with Extended Optimizations #17689

Open
WoodieDudy opened this issue Sep 25, 2023 · 16 comments
Labels
core runtime issues related to core runtime ep:CUDA issues related to the CUDA execution provider

Comments

@WoodieDudy

WoodieDudy commented Sep 25, 2023

I have a trained PyTorch model that I export to ONNX with static shapes in fp16 and run with all default optimizations enabled via ort.GraphOptimizationLevel.ORT_ENABLE_ALL.
When I compare the network outputs between PyTorch and ONNX Runtime, I observe a significant difference in the output tensors:

  • Max diff: 9.7656
  • Mean diff: 1.6991
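
For illustration, here is a minimal sketch of this kind of comparison; the ToyBlock model, input name, and opset below are just stand-ins for the real trained network, not the actual repro:

import numpy as np
import onnxruntime as ort
import torch

class ToyBlock(torch.nn.Module):
    """Stand-in for the real network: residual Add + LayerNorm, the pattern SkipLayerNormFusion targets."""
    def __init__(self, hidden=512):
        super().__init__()
        self.linear = torch.nn.Linear(hidden, hidden)
        self.norm = torch.nn.LayerNorm(hidden)

    def forward(self, x):
        return self.norm(x + self.linear(x))

model = ToyBlock().eval().half().cuda()
dummy = torch.randn(16, 4, 512, dtype=torch.float16, device="cuda")
torch.onnx.export(model, dummy, "model.onnx", input_names=["input"], opset_version=17)

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
ort_session = ort.InferenceSession(
    "model.onnx", providers=["CUDAExecutionProvider"], sess_options=sess_options
)

with torch.no_grad():
    torch_out = model(dummy).float().cpu().numpy()
ort_out = ort_session.run(None, {"input": dummy.cpu().numpy()})[0].astype(np.float32)
print("Max diff:", np.abs(torch_out - ort_out).max())
print("Mean diff:", np.abs(torch_out - ort_out).mean())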

The output difference is small in these cases:

  1. When I set the optimization level to basic:
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
ort_session = ort.InferenceSession(f"{model_name}.onnx", providers=providers, sess_options=sess_options)
  2. When I export with dynamic shapes and GraphOptimizationLevel.ORT_ENABLE_ALL
  3. With just fp32

For the cases above, the output difference looks like this:

  • Max diff: 0.2656
  • Mean diff: 0.0305

I think this is a bug in ORT's extended optimization fusions or in the fp16<>fp32 casts.
The bug is only reproducible on a trained model, which I can't post. Here is an example with a randomly initialized model instead; maybe it can help in understanding the precision casts inside.

Thanks

Not urgent

Platform

Linux

OS Version

Ubuntu 22.04

ONNX Runtime Version

onnxruntime-gpu==1.16.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA Version: 12.1

@github-actions github-actions bot added the ep:CUDA issues related to the CUDA execution provider label Sep 25, 2023
@hariharans29
Member

hariharans29 commented Sep 25, 2023

  1. So is the diff low (side question: is "low" acceptable, or just lower than in the "bad" case?) when using either fp32 or exporting with dynamic shapes?

  2. I am guessing there is a bug in an fp16 extended optimization, and static shapes probably bring a buggy optimizer into play that would otherwise stay out of the picture because the shape is not known statically.

This might take some time to track down. One quick way to continue investigating is to try disabling optimizers in the EXTENDED suite (see comment here - #17476 (comment)) and narrow it down to the optimizer that is causing the problem.
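
For example, here is a rough sketch of that narrowing-down loop. The candidate list below is only a subset of the Level 2 names, and the helper relies on the undocumented disabled_optimizers keyword of InferenceSession (used later in this thread), so treat it as illustrative rather than an official recipe:

import numpy as np
import onnxruntime as ort

# Subset of Level 2 / EXTENDED transformer names to probe one at a time.
CANDIDATES = [
    "GeluFusion", "LayerNormFusion", "SimplifiedLayerNormFusion",
    "AttentionFusion", "SkipLayerNormFusion", "BiasGeluFusion",
    "FastGeluFusion", "MatmulTransposeFusion",
]

def max_diff_with_disabled(onnx_path, feed, reference, disabled):
    """Build a session with the given optimizers disabled and return the max abs diff against a PyTorch reference output."""
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = ort.InferenceSession(
        onnx_path,
        sess_options=so,
        providers=["CUDAExecutionProvider"],
        disabled_optimizers=disabled,
    )
    out = session.run(None, feed)[0].astype(np.float32)
    return float(np.abs(out - reference).max())

# Usage (feed / reference are placeholders for the repro inputs and PyTorch outputs):
# for name in CANDIDATES:
#     print(name, max_diff_with_disabled("model.onnx", feed, reference, [name]))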

@vadimkantorov

vadimkantorov commented Sep 26, 2023

@hariharans29 Indeed, for this kind of faulty-optimizer debugging it would be useful to get a list of enabled optimizer names for a given opt_level / graph_optimization_level. Otherwise it requires some code digging to build the disabled_optimizers name list to feed in...

Can this disabled_optimizers be fed directly into the session options?

@WoodieDudy
Author

  1. A low diff is acceptable.
  2. Can you please provide a list of the optimizers that need to be disabled?

@hariharans29
Member

hariharans29 commented Sep 28, 2023

@vadimkantorov / @WoodieDudy -

  1. "....it would be useful to get a list of enabled optimizer names per a given opt_level / graph_optimization_level"

Same ask here - #17476. No API is available currently to retrieve the list of optimizers. We will need to prioritize the request.

  2. "Can this disabled_optimizers be fed directly into the session options?"

See example usage here.

  3. "Can you please provide a list of optimizers that need to be disabled?"

Extended optimizers, or Level 2 optimizers (which is where you are seeing the bug manifest), should be listed here. Based on the usage example above, to disable the fusion of Gemm + Activation use "GemmActivationFusion". There seems to be a check to ensure that the optimizers being disabled are available in the build. Keep in mind that some Level 2 optimizers may not be applicable to the CUDA EP; for example, this one is only compatible with the CPU EP.

I understand this is not the most ideal way to debug this, but it is the best way I can think of to get started with the debugging.

@wschin
Contributor

wschin commented Oct 2, 2023

I feel there should be a flag such as DUMP_GRAPH_BEFORE_AND_AFTER_TRANSFORM to automatically get the graph before and after each applied transform.
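
There is no per-transform dump today; the closest existing knob appears to be SessionOptions.optimized_model_filepath, which writes the graph out only after all enabled transforms have run. A minimal sketch of using it to diff the BASIC vs. ALL graphs (file names are placeholders):

import onnxruntime as ort

# Serialize the optimized graph at each level, then diff the two files
# (e.g. in Netron) to see which fusions the EXTENDED/ALL level added.
for level, suffix in [
    (ort.GraphOptimizationLevel.ORT_ENABLE_BASIC, "basic"),
    (ort.GraphOptimizationLevel.ORT_ENABLE_ALL, "all"),
]:
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = level
    sess_options.optimized_model_filepath = f"model.optimized.{suffix}.onnx"
    ort.InferenceSession(
        "model.onnx", sess_options=sess_options, providers=["CUDAExecutionProvider"]
    )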

@wschin wschin added the core runtime issues related to core runtime label Oct 2, 2023
@WoodieDudy
Author

@hariharans29
I took the list of optimizers from here and ran several experiments to find out which optimizer, when turned off, reduces the error.

It turned out that disabling SkipLayerNormFusion reduces the error as follows:
max difference: 9.7656 -> 0.3906
mean difference: 1.6991 -> 0.0758

sess_options = onnxruntime.SessionOptions()
# sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
self.onnxruntime_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=providers,
    sess_options=sess_options,
    disabled_optimizers=[
        # 'QDQS8ToU8Transformer',
        # 'QDQSelectorActionTransformer',
        # 'ConvActivationFusion',
        # 'GeluFusion',
        # 'LayerNormFusion',
        # 'SimplifiedLayerNormFusion',
        # 'AttentionFusion',
        # 'EmbedLayerNormFusion',
        # 'GatherToSplitFusion',
        # 'GatherToSliceFusion',
        # 'MatmulTransposeFusion',
        # 'BiasGeluFusion',
        # 'FastGeluFusion',
        # 'GeluApproximation'
        'SkipLayerNormFusion',
        # 'QuickGeluFusion',
    ]
)

@hariharans29
Member

hariharans29 commented Oct 4, 2023


Thanks for the investigation thus far. I guess we have some issue in the SLN fusion logic and/or SLN kernel implementation. CC: @tianleiwu.

@wschin
Contributor

wschin commented Oct 6, 2023

@hariharans29, anyone you'd recommend to fix this? Please feel free to assign this issue. This looks like an important operator.

@wschin wschin changed the title High Output Difference Between PyTorch and ONNX Runtime with Extended Optimizations SkipLayerNormFusion -- High Output Difference Between PyTorch and ONNX Runtime with Extended Optimizations Oct 6, 2023
@hariharans29
Member

@wschin - I am unable to recommend anyone to fix this. Maybe @tianleiwu can comment.

Also from the OP's comment: "The bug is only reproducible on a trained model, which I can't post. Here is an example with a randomly initialized model instead; maybe it can help in understanding the precision casts inside."

Not sure if the posted model repros the SLN issue, since the original issue was filed with the comment "I think this is a bug in ORT's extended optimization fusions or in the fp16<>fp32 casts." @WoodieDudy can confirm whether the posted model repros the issue. It would be easier for the investigator if the model could be trimmed down to a minimal repro rather than such a large model.

@tianleiwu
Contributor

tianleiwu commented Oct 9, 2023

I looked at bugmodel.onnx; the SkipLayerNormalization only involves 4 inputs (input, skip, weight, and bias), and the input shape is 16x4x512. No data Cast happens. The pattern is normal, so this is not an issue with the graph fusion itself.

I could do some tests with random inputs to see whether the SkipLayerNormalization kernel has a precision issue.

Update: evaluating a BERT model on SQuAD data shows no regression in end-to-end accuracy.

SkipLayerNormalization uses fp16 in part of the accumulation. In theory, it might drop precision in some scenarios. I suggest evaluating end-to-end accuracy, then deciding whether to disable the fusion or to fall back to fp32 accumulation using enable_skip_layer_norm_strict_mode.

@tianleiwu
Contributor

@WoodieDudy, could you try adding the flag {"enable_skip_layer_norm_strict_mode": True} to the CUDA provider options? Example usage:

cuda_provider_options = {"enable_skip_layer_norm_strict_mode": True}
provider_options = {"CUDAExecutionProvider": cuda_provider_options}
execution_providers = [
    (name, provider_options[name]) if name in provider_options else name for name in execution_providers
]
ort_session = InferenceSession(model_path, sess_options, providers=execution_providers)

Let me know whether that reduces the output difference.

@WoodieDudy
Author

WoodieDudy commented Oct 20, 2023

@tianleiwu @hariharans29
Sorry it's taken me a while to respond.
Here's the code for the repro.
There's also a link to the weights on which the bug is reproduced.

The model in the repro is in fp16. My assumption is that the weight recalculations for the fusion are done in fp16, and because of this you get a bad result. It would be better to cast to fp32 first, then do the fusion, then cast back to fp16.

By the way, {"enable_skip_layer_norm_strict_mode": True} solves the problem too.

@tianleiwu
Contributor

tianleiwu commented Oct 20, 2023

It is likely caused by the SkipLayerNorm kernel using fp16 in part of the accumulation. For BERT-like models, the accuracy is acceptable (for example, in the SQuAD use case).

{"enable_skip_layer_norm_strict_mode": True} forces fp32 accumulation, but at the cost of performance. It might be needed for some decoder models.

There is a trade-off between performance and accuracy here. You need to evaluate your end-to-end accuracy to pick one.
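
A small sketch of such an on/off comparison (feed and torch_reference are placeholders for the repro inputs and the PyTorch output, and the single-run timing ignores warm-up):

import time
import numpy as np
import onnxruntime as ort

def run_with_strict_mode(onnx_path, feed, strict):
    """Run the model on the CUDA EP, toggling fp32 accumulation in SkipLayerNorm."""
    providers = [("CUDAExecutionProvider", {"enable_skip_layer_norm_strict_mode": strict})]
    session = ort.InferenceSession(onnx_path, providers=providers)
    start = time.perf_counter()
    out = session.run(None, feed)[0].astype(np.float32)
    return out, time.perf_counter() - start

# Usage:
# for strict in (False, True):
#     out, seconds = run_with_strict_mode("model.onnx", feed, strict)
#     print(f"strict={strict}: max diff {np.abs(out - torch_reference).max():.4f}, {seconds * 1000:.1f} ms")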

@vadimkantorov

vadimkantorov commented Oct 20, 2023

@tianleiwu is there any weight recomputation happening during fusion? Is it by chance done in fp16?

If your hypothesis is true, why does the fusion change things so much? Why would enabling the fusion change the accumulation dtype?

Shouldn't accumulation in LayerNorm always happen in fp32?

@tianleiwu
Contributor

tianleiwu commented Oct 20, 2023

@vadimkantorov, there is NO weight recomputation during fusion. See the source code:
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc

The accumulation dtype is hard-coded in the CUDA kernel and is not related to the fusion. For example, the layer normalization kernel uses T everywhere (which means it uses T=fp16 during part of the accumulation):

__device__ inline void LayerNorm(
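
As a quick illustration of why the accumulation dtype matters for a reduction over a 512-wide hidden dimension (the numbers below are made up and unrelated to the actual kernel, which accumulates in parallel and so loses less than this worst-case loop):

import numpy as np

rng = np.random.default_rng(0)
# 512 values with a large common offset, roughly like `input + skip` activations.
x = (rng.standard_normal(512) * 0.5 + 8.0).astype(np.float16)

acc16 = np.float16(0.0)
for v in x:                        # naive sequential accumulation in fp16:
    acc16 = np.float16(acc16 + v)  # every partial sum is rounded back to fp16
acc32 = x.astype(np.float32).sum()

print("fp16-accumulated mean:", float(acc16) / x.size)
print("fp32-accumulated mean:", float(acc32) / x.size)
# fp16 partial sums lose low-order bits that fp32 accumulation (strict mode) keeps;
# the same applies to the mean/variance reductions inside SkipLayerNormalization.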

@vadimkantorov

vadimkantorov commented Oct 20, 2023

Does this mean that the accumulation dtype / computation precision actually differs between the regular LayerNorm kernel and the fused kernel used when SkipLayerNormFusion is enabled? (Model weights and inputs are fp16 in both cases.)
