[Performance] Significant output mismatch on different GPU #19288

Closed
dhorkel opened this issue Jan 26, 2024 · 6 comments
Labels
ep:CUDA (issues related to the CUDA execution provider), ep:TensorRT (issues related to the TensorRT execution provider), model:transformer (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.), performance (issues related to performance regressions)

Comments

dhorkel commented Jan 26, 2024

Describe the issue

Significant ONNX model output mismatch (up to a 165% difference) when using an Ada-series GPU (compute capability 8.6 and 8.9) versus older GPUs (tested with compute capability 6.1 and 7.5). It occurs with TensorRT execution (tested in C++ with TRT 8.5.3.1 and in Python via onnxruntime with TRT 8.5.1, both with CUDA 11.8) and with CUDA execution (using the onnxruntime CUDAExecutionProvider). Using onnxruntime, I also ran the model with CPUExecutionProvider to get a baseline. The CPUExecutionProvider outputs match on both computers and show a similar difference from the TensorRT and CUDA outputs. This was discovered on internal models, but I was able to reproduce it with a model from Hugging Face.

Specs:
Ada-series GPU: RTX 2000 Ada Generation Laptop GPU, 8 GB of memory, compute capability 8.9
(The issue also appeared, though it was not thoroughly tested, on an A10 card on a remote server; details not known.)

Older GPUs: Quadro P4000, 8 GB, compute capability 6.1
T1000, 8 GB, compute capability 7.5

To reproduce

Used the model found here: https://huggingface.co/microsoft/resnet-18/blob/refs%2Fpr%2F3/onnx/model.onnx

I ran this script:

import onnxruntime as rt
import numpy as np

onnx_model = "model.onnx"
# Execution providers under test (swap the GPU provider between the first two)
ep_list = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']

# GPU session (change to 'TensorrtExecutionProvider' to test TensorRT)
sess = rt.InferenceSession(onnx_model, providers=['CUDAExecutionProvider'])
# CPU baseline session
sess2 = rt.InferenceSession(onnx_model, providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
input_shape = sess.get_inputs()[0].shape
output_shape = sess.get_outputs()[0].shape

if input_shape[0] == 'batch_size':
    input_shape[0] = 1
if output_shape[0] == 'batch_size':
    output_shape[0] = 1

if input_shape[1] == 'num_channels':
    input_shape[1] = 3
if output_shape[1] == 'num_channels':
    output_shape[1] = 3

if input_shape[2] == 'height':
    input_shape[2] = 512

if input_shape[3] == 'width':
    input_shape[3] = 512



# print model input/output info
print("Input name  :", input_name)
print("Input shape :", input_shape)
print("Output name :", output_name)
print("Output shape:", output_shape)


# random input
np.random.seed(0)
dummy_input = np.random.random(input_shape).astype(np.float32)

# run the inference
res = sess.run([output_name], {input_name: dummy_input})[0]
res2 = sess2.run([output_name], {input_name: dummy_input})[0]

# Max difference between the two outputs
res_ravel = res.ravel()
res2_ravel = res2.ravel()

# skip any where both are near zero
nonzero = np.logical_or(np.abs(res_ravel) > 1e-5, np.abs(res2_ravel) > 1e-5)
res_ravel = res_ravel[nonzero]
res2_ravel = res2_ravel[nonzero]

# symmetric relative difference 2*|a - b| / |a + b|: max, mean, median
print(2*np.max(np.abs(res_ravel - res2_ravel)/np.abs(res_ravel+res2_ravel)))
print(2*np.mean(np.abs(res_ravel - res2_ravel)/np.abs(res_ravel+res2_ravel)))
print(2*np.median(np.abs(res_ravel - res2_ravel)/np.abs(res_ravel+res2_ravel)))

Inside a Docker image created with this Dockerfile:

ARG BASE_IMAGE=nvcr.io/nvidia/tensorrt:22.12-py3
FROM $BASE_IMAGE
SHELL ["/bin/bash", "-c"]

ARG INFERENCE_TYPE="trt"

ENV DEBIAN_FRONTEND=noninteractive

ENV PIP_ROOT_USER_ACTION=ignore

ENV PIP_NO_CACHE_DIR=off

RUN apt-get update && apt-get install -y dos2unix &&\
    apt-get install -y software-properties-common ffmpeg libsm6 libxext6    &&\
    add-apt-repository ppa:deadsnakes/ppa -y &&\
    apt-get install -y python3.10-dev python3.10-distutils &&\
    update-alternatives --install /usr/bin/python3 python /usr/bin/python3.10 1 &&\
    curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10

WORKDIR /usr/local/trt_testing/

RUN python3 -m pip install --upgrade pip && \
    python3 -m pip install onnxruntime-gpu==1.15.1 numpy==1.25.2


# Copy onnx files in
COPY model.onnx .
COPY onnxruntime_testing.py .

# Speed up CUDA module loading
ENV CUDA_MODULE_LOADING="LAZY"

CMD ["bash"]

On the machine with the new GPU the output is:
Max: 1.6485868692398071 (165% difference)
Mean: 0.006287568248808384
Median: 0.0008300603367388248

On the machine with the old GPUs (same exact output with each GPU):
Max: 0.00023340669577009976 (0.02% difference)
Mean: 2.1887146886001574e-06
Median: 3.391258474039205e-07

CPU outputs match on both machines.

Urgency

Currently blocking deployment in a region where the cloud provider does not offer older GPUs.

Platform

Linux

OS Version

Ubuntu 20.04 (running in docker)

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

onnxruntime-gpu==1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU, CUDA, TensorRT

Execution Provider Library Version

CUDA 11.8 and TensorRT 8.5

Model File

https://huggingface.co/microsoft/resnet-18/blob/refs%2Fpr%2F3/onnx/model.onnx

Is this a quantized model?

Unknown

github-actions bot added the ep:CUDA, ep:TensorRT, and model:transformer labels on Jan 26, 2024
dhorkel changed the title from "[Performance]" to "[Performance] Significant output mismatch on different GPU" on Jan 26, 2024
jywu-msft (Member) commented:

Does this happen with the latest production TensorRT version (TensorRT 8.6.1)?

dhorkel (Author) commented Jan 26, 2024

I haven't tested that yet, but considering it happens with the CUDAExecutionProvider as well, I suspect it will. I can try updating my Dockerfile base image to a newer version.

tianleiwu (Contributor) commented Jan 26, 2024

You can try setting the environment variable NVIDIA_TF32_OVERRIDE=0 to see whether it helps. This is because SM >= 80 supports TF32, and TF32 increases the mismatch. In PyTorch, there are settings like torch.backends.cudnn.allow_tf32 and torch.backends.cuda.matmul.allow_tf32; in ONNX Runtime there is currently no similar configuration.

On my test machine, it helps reduce the difference.

D:\test>python test.py
Input name  : pixel_values
Input shape : [1, 3, 512, 512]
Output name : logits
Output shape: [1, 1000]
0.33340126276016235
0.0008839640067890286
0.00010457751341164112

D:\test>set NVIDIA_TF32_OVERRIDE=0
D:\test>python test.py
Input name  : pixel_values
Input shape : [1, 3, 512, 512]
Output name : logits
Output shape: [1, 1000]
0.0004496636101976037
2.7148862500325777e-06
3.557692593858519e-07
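
A minimal sketch of applying the same workaround from inside the repro script rather than the shell; the assumption is that the variable must be in the environment before the CUDA libraries initialize, so it is set before onnxruntime is imported:

import os

# Disable TF32 for this process only. Assumption: the override has to be set
# before cuBLAS/cuDNN initialize, hence before importing onnxruntime.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import onnxruntime as rt

sess = rt.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])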

dhorkel (Author) commented Jan 26, 2024

NVIDIA_TF32_OVERRIDE=0

I think that did it! Output now:

0.00019308803894091398
1.8716610838964698e-06
3.5120950769851333e-07

Thank you so much! I'm going to test with my original C++ program to see if that resolves it.

jywu-msft (Member) commented:

TensorRT 8.6.1 also has this in the release notes.

"For some networks, containing matrix multiplication operations on A100, using TF32 could cause accuracy degradation. Disabling TF32 was the workaround. This issue has been fixed in this release."

dhorkel (Author) commented Jan 26, 2024

Thank you so much @tianleiwu and @jywu-msft! I had been stuck on this for about 2 weeks. I should have asked sooner.

Setting NVIDIA_TF32_OVERRIDE=0 fixed my problem both in onnxruntime and in my C++ code. I will also look into upgrading to TRT 8.6.

dhorkel closed this as completed on Jan 26, 2024
tianleiwu added a commit that referenced this issue Feb 6, 2024
[TF32](https://blogs.nvidia.com/blog/tensorfloat-32-precision-format/)
can help boost performance on GPUs with SM >= 80. Sometimes users observe accuracy loss, or need to disable TF32 for testing purposes. To disable TF32, it is also possible to set the environment variable `NVIDIA_TF32_OVERRIDE=0`. However, sometimes we do not want to use an environment variable, to avoid impacting other applications, or we want finer control (like one session using TF32 and another session not). This provider option helps with that.

Here we add a provider option `use_tf32`. When `use_tf32 = 0`, TF32 is disabled for float MatMul/GEMM in cuBLAS. This applies to the MatMulNBits, Attention, LongformerAttention, PackedAttention, and PackedMultiHeadAttention operators when a float GEMM is used internally in the operator. Note that it does not impact other data types; for example, FP8 GEMM could still use TF32 in accumulation.

Previously, cublasGemmStridedBatchedHelper did not use TF32 in inference. Here we enable TF32 by default, so we might observe a speed-up for FP32 transformer models on SM >= 80.

There is another PR that will enable the option for cuDNN Conv later.

### Motivation and Context

#15407
#19288
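
A minimal sketch of how the resulting per-session control could look, assuming an ONNX Runtime build that includes this change; the option name use_tf32 is taken from the commit message above:

import onnxruntime as rt

# Disable TF32 for this session only via a CUDA execution provider option,
# leaving other sessions and processes unaffected.
sess = rt.InferenceSession(
    "model.onnx",
    providers=[("CUDAExecutionProvider", {"use_tf32": 0})],
)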
sophies927 added the performance label on Feb 22, 2024