[Performance] Significant output mismatch on different GPU #19288
Comments
Does this happen with the latest production TensorRT version (TensorRT 8.6.1)?
I haven't tested yet, but considering it happens with CUDAExecutionProvider as well, I suspect it will. I can try updating my Dockerfile base image to a newer version.
You can try setting the environment variable `NVIDIA_TF32_OVERRIDE=0`. On my test machine, it helps reduce the difference.
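For reference, a minimal sketch of doing this from Python (the variable must be set before the CUDA libraries initialize, so before importing onnxruntime; `model.onnx` is a placeholder):

```python
import os

# NVIDIA_TF32_OVERRIDE=0 disables TF32 process-wide in NVIDIA math
# libraries (cuBLAS/cuDNN). Set it before importing onnxruntime so the
# CUDA libraries see it when they initialize.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
```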
I think that did it! Output now: 0.00019308803894091398. Thank you so much! I'm going to test with my original C++ program to see if that resolves it.
TensorRT 8.6.1 also has this in the release notes: "For some networks, containing matrix multiplication operations on A100, using TF32 could cause accuracy degradation. Disabling TF32 was the workaround. This issue has been fixed in this release."
Thank you so much @tianleiwu and @jywu-msft! I had been stuck on this for about 2 weeks. I should have asked sooner. Setting `NVIDIA_TF32_OVERRIDE=0` resolved the issue.
[TF32](https://blogs.nvidia.com/blog/tensorfloat-32-precision-format/) can boost performance on GPUs with SM >= 80. Sometimes users observe accuracy loss, or need to disable TF32 for testing purposes. TF32 can also be disabled by setting the environment variable `NVIDIA_TF32_OVERRIDE=0`; however, an environment variable can impact other applications and offers no finer control (such as one session using TF32 while another does not). A provider option helps in those cases.

This PR adds a provider option `use_tf32`. When `use_tf32 = 0`, TF32 is disabled for float MatMul/GEMM in cuBLAS. It also applies to the MatMulNBits, Attention, LongformerAttention, PackedAttention, and PackedMultiHeadAttention operators when a float GEMM is used internally. Note that it does not impact other data types; for example, FP8 GEMM may still use TF32 in accumulation.

Previously, cublasGemmStridedBatchedHelper did not use TF32 in inference. TF32 is now enabled by default, so FP32 transformer models may see a speedup on SM >= 80. A later PR will enable the option for cuDNN Conv.

### Motivation and Context

#15407 #19288
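A sketch of how the option could be used from the Python API (assuming a build that includes `use_tf32`; `model.onnx` is a placeholder):

```python
import onnxruntime as ort

# One session keeps TF32 enabled (the new default on SM >= 80)...
sess_tf32 = ort.InferenceSession(
    "model.onnx",
    providers=[("CUDAExecutionProvider", {"use_tf32": 1})],
)

# ...while another disables it, without touching NVIDIA_TF32_OVERRIDE
# or affecting other sessions and applications.
sess_no_tf32 = ort.InferenceSession(
    "model.onnx",
    providers=[("CUDAExecutionProvider", {"use_tf32": 0})],
)
```

A per-session option avoids the process-wide side effects of the environment variable.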
Describe the issue
Significant ONNX model output mismatch (up to a 165% difference) when using an Ada-series GPU (compute capability 8.6 and 8.9) versus older GPUs (tested compute capability 6.1 and 7.5). Occurs with TensorRT execution (tested in C++ with TRT 8.5.3.1, and in Python via onnxruntime with TRT 8.5.1, both with CUDA 11.8) and with CUDA execution (onnxruntime CUDAExecutionProvider). Using onnxruntime, I also ran the model with CPUExecutionProvider to get a baseline: the CPUExecutionProvider outputs match on both computers and show a similar difference from the TensorRT and CUDA executions. This was discovered on internal models, but I was able to reproduce it with a model from Hugging Face.
Specs:
- Newer GPU: RTX 2000 Ada Generation Laptop GPU, 8 GB of memory, compute capability 8.9
  (the mismatch also appeared on an A10 card on a remote server, but that was not thoroughly tested)
- Older GPUs: Quadro P4000, 8 GB, compute capability 6.1; T1000, 8 GB, compute capability 7.5
To reproduce
Used model found here: https://huggingface.co/microsoft/resnet-18/blob/refs%2Fpr%2F3/onnx/model.onnx
I ran this script:
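(The script itself is not reproduced here; below is a minimal sketch of an equivalent comparison. The input shape and the element-wise relative-difference metric are assumptions inferred from the percentages quoted below.)

```python
import numpy as np
import onnxruntime as ort

MODEL = "model.onnx"  # the resnet-18 model linked above
x = np.random.rand(1, 3, 224, 224).astype(np.float32)  # assumed input shape

def run(providers):
    sess = ort.InferenceSession(MODEL, providers=providers)
    input_name = sess.get_inputs()[0].name
    return sess.run(None, {input_name: x})[0]

cpu = run(["CPUExecutionProvider"])   # baseline
gpu = run(["CUDAExecutionProvider"])  # or TensorrtExecutionProvider

# Element-wise relative difference against the CPU baseline.
rel = np.abs(gpu - cpu) / (np.abs(cpu) + 1e-12)
print("Max:", rel.max())
print("Mean:", rel.mean())
print("Median:", np.median(rel))
```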
Inside a Docker image created with this Dockerfile:
On the machine with the new GPU the output is:
Max: 1.6485868692398071 (165% difference)
Mean: 0.006287568248808384
Median: 0.0008300603367388248
On the machine with the old GPUs (identical output with each GPU):
Max: 0.00023340669577009976 (0.02% difference)
Mean: 2.1887146886001574e-06
Median: 3.391258474039205e-07
CPU outputs match on both machines.
Urgency
Currently blocking deployment in a region where the cloud provider does not offer older GPUs.
Platform
Linux
OS Version
Ubuntu 20.04 (running in docker)
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
onnxruntime-gpu==1.15.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU, CUDA, TensorRT
Execution Provider Library Version
CUDA 11.8 and TensorRT 8.5
Model File
https://huggingface.co/microsoft/resnet-18/blob/refs%2Fpr%2F3/onnx/model.onnx
Is this a quantized model?
Unknown