TensorRT EP's inference results are abnormal. #21457

Open
c1aude opened this issue Jul 23, 2024 · 37 comments
Labels
ep:TensorRT (issues related to TensorRT execution provider)

Comments

c1aude commented Jul 23, 2024

Describe the issue

Inference results are abnormal when running YOLOv7 models with the TensorRT EP.

We have confirmed that the results are normal when using CPU and CUDA.

The issue was reproducible in versions 1.18.0 to 1.18.1 using TensorRT 10, and did not occur in versions 1.17.3 and earlier using TensorRT 8.6.1.6.

When using TensorRT 10, are any additional steps required when converting PyTorch models to ONNX, compared to TensorRT 8?

TensorRT result:

[screenshot]

CPU or CUDA result:

[screenshot]

To reproduce

The code we used for testing is shown below.

import numpy as np
import onnxruntime as ort
import cv2
import matplotlib.pyplot as plt
import torch

class YOLOv7:
    
    def __init__(self, onnx_model, input_image, confidence_thresh, iou_thresh):
        self.onnx_model = onnx_model
        self.input_image = input_image
        self.confidence_thresh = confidence_thresh
        self.iou_thresh = iou_thresh
        self.classes = ["0","1", "2", "3", "4", "5", "6", "7", "8", "9",
                        "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
                        "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
                        "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
                        "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z",
                        "#", ".", "_", "carve", "print", "*", "+", "cut", "double", "copper",
                        ",", "/", "(", ")", "&", ":", "_", "~", "%", "=", "<", ">"]
        self.color_palette = np.random.uniform(0, 255, size = (len(self.classes),3))
        
    def draw_detections(self, img, box, score, class_id):
        x1, y1, w, h = box
        
        color = self.color_palette[class_id]
        cv2.rectangle(img,(int(x1), int(y1)), (int(x1+w), int(y1+h)), color, 2)
        label = self.classes[class_id]
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        label_x = x1
        label_y = y1 - 10 if y1 -10 > label_height else y1 + 10
        cv2.rectangle(
            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
        )
        cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
        
    def preprocess(self):
        self.img = cv2.imread(self.input_image)
        self.img_height, self.img_width = self.img.shape[:2]
        
        img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.input_width, self.input_height))
        image_data = np.array(img) / 255.0
        image_data = np.transpose(image_data, (2, 0 ,1))
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
        
        return image_data
    
    def postprocess(self, input_image, output):
        output = output[0]
        print(output.shape)
        rows = output.shape[0]
        boxes = []
        scores = []
        class_ids = []
        
        x_factor = self.img_width / self.input_width
        y_factor = self.img_height / self.input_height
        print(output[0])
        for i in range(rows):
            classes_scores = output[i][6]
            
            if classes_scores >= self.confidence_thresh:
                class_id = output[i][5]
                
                x1, y1, x2, y2 = output[i][1], output[i][2], output[i][3], output[i][4]
                
                left = int(x1 * x_factor)
                top = int(y1 * y_factor)
                width = int((x2 - x1) * x_factor)
                height = int((y2 - y1) * y_factor)

                class_ids.append(class_id)
                scores.append(classes_scores)
                boxes.append([left, top, width, height])
                
        indices = cv2.dnn.NMSBoxes(boxes, scores, self.confidence_thresh, self.iou_thresh)
        
        for i in indices:
            box = boxes[i]
            score = scores[i]
            class_id = int(class_ids[i])
            self.draw_detections(input_image, box, score, class_id)
        return input_image

    def inference(self):
        providers = [
            ('CPUExecutionProvider', {
                'intra_op_num_threads': 4,  # max threads for a single operator
                'inter_op_num_threads': 1   # max threads across operators
            }),

            # ('CUDAExecutionProvider', {
            #     'device_id': 0,            # GPU ID to use
            #     'arena_extend_strategy': 'kNextPowerOfTwo',
            #     'gpu_mem_limit': 2 * 1024 * 1024 * 1024,  # GPU memory limit (2 GB)
            #     'cudnn_conv_algo_search': 'EXHAUSTIVE',
            #     'do_copy_in_default_stream': True,
            #     'enable_cuda_graph': False,  # disable CUDA graph optimization
            # }),

            # ('TensorrtExecutionProvider', {
            #     'device_id': 0,            # GPU ID to use
            #     'trt_max_partition_iterations': 10,  # iterations for partition optimization
            #     'trt_max_workspace_size': 2 * 1024 * 1024 * 1024,  # workspace limit (2 GB)
            #     'trt_min_subgraph_size': 1,  # minimum subgraph size
            #     'trt_engine_cache_enable': False,  # whether to cache built engines
            #     'trt_fp16_enable': True  # enable FP16
            #     #'trt_int8_enable': True  # enable INT8
            #     })
        ]
        session = ort.InferenceSession(self.onnx_model, providers=providers)
        model_inputs = session.get_inputs()
        
        input_shape = model_inputs[0].shape
        self.input_width = input_shape[2]
        self.input_height = input_shape[3]
        
        img_data = self.preprocess()
        outputs = session.run(None, {model_inputs[0].name: img_data})
        return self.postprocess(self.img, outputs)

detection = YOLOv7(r"E:\YOLOv7\best_562_0712.onnx", r"E:\TEST\test.png", 0.9, 0.5)
output_image = detection.inference()

plt.imshow(output_image)
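
For reference, a minimal sketch of a provider list that exercises the TensorRT EP instead of the CPU EP, mirroring the commented-out options in the script above (the model path is the one from the script; the option values are illustrative, not recommendations):

providers = [
    ('TensorrtExecutionProvider', {
        'device_id': 0,
        'trt_fp16_enable': True,  # FP16 on, as in the commented-out block above
    }),
    'CPUExecutionProvider',  # fallback for any unsupported subgraphs
]
session = ort.InferenceSession(r"E:\YOLOv7\best_562_0712.onnx", providers=providers)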

Urgency

No response

Platform

Windows

OS Version

Windows 11

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.18.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

TensorRT

Execution Provider Library Version

CUDA 11.8, Cudnn 8.9.7, TensorRT 10.2.0.19

github-actions bot added the ep:CUDA, ep:TensorRT, and platform:windows labels Jul 23, 2024
yf711 commented Jul 23, 2024

Hi @c1aude could you also share the model/the requirements.txt of your python env/test image to repro this issue?

c1aude closed this as completed Jul 24, 2024

c1aude commented Jul 24, 2024

> Hi @c1aude could you also share the model/the requirements.txt of your python env/test image to repro this issue?

Hello @yf711, here are the env and model files used for inference.

Image used for inference: [attached image "test"]

env requirements.txt :
https://drive.google.com/file/d/16wj3sa0JFyBOTpPit_2iqBU0trZtITSw/view?usp=sharing

test model:
https://drive.google.com/file/d/12GZKzMf5Pq1_qgKiwPF9HOCBwWM_jvsz/view?usp=sharing

c1aude reopened this Jul 24, 2024
sophies927 removed the ep:CUDA and platform:windows labels Jul 25, 2024
This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

github-actions bot added the stale label Aug 25, 2024
jywu-msft removed the stale label Aug 26, 2024
jywu-msft (Member) commented:

adding a note to further debug this issue

jingyanwangms (Contributor) commented:

Hi @c1aude, thank you for the detailed repro information. Unfortunately, when I try to run the script, I get the error below. Can you please verify the script and share a version that works?
Traceback (most recent call last):
  File "C:\Users\jingywa\OneDrive - Microsoft\source\21457\repro.py", line 125, in <module>
    output_image = detection.inference()
  File "C:\Users\jingywa\OneDrive - Microsoft\source\21457\repro.py", line 122, in inference
    return self.postprocess(self.img, outputs)
  File "C:\Users\jingywa\OneDrive - Microsoft\source\21457\repro.py", line 64, in postprocess
    if classes_scores >= self.confidence_thresh:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

BengtGustafsson (Contributor) commented:

Is there any progress on this one? We get a similar problem with some of our networks, but not all. I'm not sure it worked on onnxruntime 1.17 with TRT 8.6.1.6, which we used before, but I think it did. Now, with 1.19.2 and TRT 10.4.0.26, it definitely differs from the other providers.

jywu-msft (Member) commented:

> Is there any progress on this one? We get a similar problem with some of our networks, but not all. I'm not sure it worked on onnxruntime 1.17 with TRT 8.6.1.6, which we used before, but I think it did. Now, with 1.19.2 and TRT 10.4.0.26, it definitely differs from the other providers.

can you provide some repro test cases for @jingyanwangms to investigate further? thanks!

BengtGustafsson (Contributor) commented:

I log errors from the ONNX Runtime log, and it sometimes writes:

onnxruntime: [2024-09-17 16:10:25 ERROR] Error Code: 9: Skipping tactic 0xaa15e43058248292 due to exception canImplement1 [tensorrt_execution_provider.h:88 onnxruntime::TensorrtLogger::log]

To me this does not sound like an error, more like a warning. I have no idea what the hex code is but I get the same warning with many different codes.

jingyanwangms (Contributor) commented:

Yes, this is a warning. It indicates TensorRT is skipping a specific tactic (optimization approach) due to an internal issue in the implementation (canImplement1). This should not block running your model on TensorRT, since other tactics should kick in.
On our side, it is something that we can investigate. Can you please share a detailed repro of how you encountered the warning? Please do so in a separate issue, since it seems to be a different issue.
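
For anyone trying to surface these TensorRT EP log messages, a minimal sketch (model path hypothetical) that raises ONNX Runtime's log verbosity so tactic messages like the one quoted above are emitted:

import onnxruntime as ort

so = ort.SessionOptions()
so.log_severity_level = 0  # 0 = VERBOSE; surfaces TensorRT tactic/warning messages

sess = ort.InferenceSession(
    "model.onnx",  # hypothetical path
    sess_options=so,
    providers=["TensorrtExecutionProvider", "CPUExecutionProvider"],
)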

samsonyilma commented Sep 17, 2024

I too observe similar behavior: different results with CUDAExecutionProvider vs TensorrtExecutionProvider when using onnxruntime-gpu==1.19.2 with tensorrt==10.4.0.

I get the same outputs (CUDA vs TensorRT providers) when using onnxruntime-gpu==1.17.1 with tensorrt==8.6.0.

The script I am using:

import numpy as np
import onnxruntime as ort
import requests
import timm
import torch
from matplotlib import pyplot as plt
from PIL import Image

# URL for ImageNet class labels
image_net_labels_url = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"


def read_class_labels(url):
    response = requests.get(url)
    response.raise_for_status()  # Ensure we notice bad responses
    text = response.text
    dict_text = eval(text)
    return dict_text


def read_test_image(image_url="http://images.cocodataset.org/val2017/000000039769.jpg"):

    # load image from url using PIL
    img = Image.open(requests.get(image_url, stream=True).raw)

    # resize image to 224x224
    img = img.resize((224, 224))

    # convert image to numpy array
    img_array = np.array(img)

    # change the order of the channels from HWC to CHW
    img_array = np.transpose(img_array, (2, 0, 1))
    img_array = np.ascontiguousarray(img_array)

    img_array = (img_array - 127.5) / 127.5

    input_tensor = torch.tensor(img_array, dtype=torch.float32).unsqueeze(0)

    return input_tensor


def create_model():
    model = timm.create_model(
        "swin_base_patch4_window7_224", pretrained=True, num_classes=1000, in_chans=3, img_size=(224, 224)
    )
    return model


def export_to_onnx(model, input_shape, output_path, export_device="cuda"):
    model.eval()
    sample = torch.randn(input_shape).unsqueeze(0).to(export_device)
    torch.onnx.export(
        model.to(export_device),
        sample,
        output_path,
        export_params=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )
    return output_path


def run_onnx_inference(onnx_model_path, input_tensor, cuda_device_id=0, EP_name="CUDAExecutionProvider"):
    """ 
    Run inference using ONNX model.
    Args: 
        onnx_model_path (str): Path to the ONNX model
        input_tensor (torch.Tensor): Input tensor
        cuda_device_id (int): CUDA device id
        EP_name (str): Execution provider name
    """
    try:
        ort_session = ort.InferenceSession(
            onnx_model_path,
            sess_options=None,
            providers=[(EP_name, {"device_id": cuda_device_id})],
        )
        ort_input = {"input": input_tensor.numpy()}
        outputs = ort_session.run(None, ort_input)
        return outputs[0]
    except Exception as e:
        raise ValueError(f"Failed to run model with EP {EP_name}: {e}")


def top_k_class_ids(predictions, k=5):
    top_k = np.argsort(predictions, axis=1)[:, -k:][:, ::-1]
    return top_k


def compare_execution_provider_outputs():

    class_labels = read_class_labels(image_net_labels_url)

    model = create_model()
    onnx_output_filename = "./test_swin.onnx"

    export_to_onnx(model, (3, 224, 224), onnx_output_filename)

    input_tensor = read_test_image()

    for ep_name in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
        output = run_onnx_inference(onnx_model_path=onnx_output_filename, input_tensor=input_tensor, EP_name=ep_name)
        top_k = top_k_class_ids(output)[0]

        print(f"\n{ep_name} outputs:")
        for k, class_id in enumerate(top_k):
            print(
                f"top-{k+1} label = {class_labels[class_id]:30} conf= {output[0][class_id]:.3f}, class_id = {class_id}"
            )


if __name__ == "__main__":

    compare_execution_provider_outputs()

BengtGustafsson commented Sep 18, 2024

That seems like a clear-cut case. Our case involves proprietary models and C++ code that I can't share. We don't use the Python bindings, so unfortunately it would be a big job to create a repro.

In our case we get small differences all over the output image, on the order of 1e-3. The errors are in the range of what would be expected if the model were running in 16-bit mode. As I understand it, TensorRT 10 now sets precision from the ONNX data. Maybe it has the defaults wrong if no precision is set? Our coefficient tensors are 32-bit floats, though.

samsonyilma commented:

> That seems like a clear-cut case. Our case involves proprietary models and C++ code that I can't share. We don't use the Python bindings, so unfortunately it would be a big job to create a repro.
>
> In our case we get small differences all over the output image, on the order of 1e-3. The errors are in the range of what would be expected if the model were running in 16-bit mode. As I understand it, TensorRT 10 now sets precision from the ONNX data. Maybe it has the defaults wrong if no precision is set? Our coefficient tensors are 32-bit floats, though.

I have tried setting trt_fp16_enable and trt_int8_enable to False. The ONNX exported in the script should be FP32. Not sure if there is another flag I can set for precision?

The results I get are not close at all. I run into the same issue with another model architecture as well (which I can't share).
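
For reference, a minimal sketch of how those precision flags are passed to the TensorRT EP (model filename taken from the script above; the flag names are documented TensorRT EP options):

import onnxruntime as ort

providers = [
    ("TensorrtExecutionProvider", {
        "device_id": 0,
        "trt_fp16_enable": False,  # keep the engine in FP32
        "trt_int8_enable": False,
    }),
    ("CUDAExecutionProvider", {"device_id": 0}),
]
session = ort.InferenceSession("./test_swin.onnx", providers=providers)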

Output of the above script:
With onnxruntime-gpu==1.17.1 and tensorrt==8.6.0, both EPs give correct output (the input image has two cats).

CUDAExecutionProvider outputs:
top-1 label = Egyptian cat                   conf= 6.487, class_id = 285
top-2 label = tabby, tabby cat               conf= 5.889, class_id = 281
top-3 label = tiger cat                      conf= 5.149, class_id = 282
top-4 label = parachute, chute               conf= 3.801, class_id = 701
top-5 label = remote control, remote         conf= 3.738, class_id = 761

TensorrtExecutionProvider outputs:
top-1 label = Egyptian cat                   conf= 6.487, class_id = 285
top-2 label = tabby, tabby cat               conf= 5.889, class_id = 281
top-3 label = tiger cat                      conf= 5.149, class_id = 282
top-4 label = parachute, chute               conf= 3.801, class_id = 701
top-5 label = remote control, remote         conf= 3.738, class_id = 761

With onnxruntime-gpu==1.19.2 and tensorrt==10.4.0, the outputs are completely different:

CUDAExecutionProvider outputs:
top-1 label = Egyptian cat                   conf= 6.487, class_id = 285
top-2 label = tabby, tabby cat               conf= 5.889, class_id = 281
top-3 label = tiger cat                      conf= 5.149, class_id = 282
top-4 label = parachute, chute               conf= 3.801, class_id = 701
top-5 label = remote control, remote         conf= 3.738, class_id = 761

TensorrtExecutionProvider outputs:
top-1 label = Appenzeller                    conf= 9.482, class_id = 240
top-2 label = toyshop                        conf= 8.440, class_id = 865
top-3 label = Bernese mountain dog           conf= 7.478, class_id = 239
top-4 label = scorpion                       conf= 7.301, class_id = 71
top-5 label = sewing machine                 conf= 7.164, class_id = 786

jingyanwangms commented Sep 18, 2024

@samsonyilma Thank you for the simple repro. Yes, I can see different results for CUDA vs TensorRT with onnxruntime-gpu==1.19.2 and tensorrt==10.4.0. We're investigating on our side.

BengtGustafsson (Contributor) commented:

For us it seems to happen only on Windows; we get correct results on Linux. We will, however, continue checking that this information is correct; we could have messed up the version upgrade on Linux or something like that.

c1aude commented Sep 19, 2024

> Yes, this is a warning. It indicates TensorRT is skipping a specific tactic (optimization approach) due to an internal issue in the implementation (canImplement1). This should not block running your model on TensorRT, since other tactics should kick in. On our side, it is something that we can investigate. Can you please share a detailed repro of how you encountered the warning? Please do so in a separate issue, since it seems to be a different issue.

I tried running it again with that code and it worked fine.

It looks like the same issue has been reproduced by another commenter, but if you need to test my code, could you please make the following changes and test it?

On lines 64 and 77, change self.confidence_thresh to 0.9 and self.iou_thresh to 0.5.

jywu-msft (Member) commented:

> Yes, this is a warning. It indicates TensorRT is skipping a specific tactic (optimization approach) due to an internal issue in the implementation (canImplement1). This should not block running your model on TensorRT, since other tactics should kick in. On our side, it is something that we can investigate. Can you please share a detailed repro of how you encountered the warning? Please do so in a separate issue, since it seems to be a different issue.
>
> I tried running it again with that code and it worked fine.
>
> It looks like the same issue has been reproduced by another commenter, but if you need to test my code, could you please make the following changes and test it?
>
> On lines 64 and 77, change self.confidence_thresh to 0.9 and self.iou_thresh to 0.5.

Thanks, we will look at your repro case too.

BengtGustafsson (Contributor) commented:

I mentioned above that it works on Linux. But this may not be a Linux/Windows issue; it could also be that we don't include the CUDA provider in the Linux build, but we do in the Windows build.

May I ask c1aude and samsonyilma which OS you're on, and which providers you have included in your onnxruntime builds or downloaded packages?

samsonyilma commented:

I am using Ubuntu 22.04 and the Python API, with the onnxruntime-gpu and tensorrt packages.

c1aude commented Sep 20, 2024

I'm using Windows 11, building from source.
The environment is CUDA 11.8, cuDNN 8.9.7, TensorRT 10.2.0.19.
The compiler is Visual C++ 17.8 in Visual Studio 2022.

BengtGustafsson (Contributor) commented:

A colleague of mine removed the last layers one by one until the error disappeared, then added back the tentative culprit layer. This was a maxpool layer, but we guess that there is some optimization involving the preceding layers. The sequence is conv/relu/maxpool, and our guess is that for some reason the input to the maxpool is truncated to 16-bit float, although our network is 32-bit float in all its parts. This is at least consistent with the magnitude of the errors.
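
A sketch of that bisection approach using onnx.utils.extract_model, which slices a model between named tensors (paths and tensor names here are hypothetical; the real names can be read from the graph, e.g. with Netron):

import onnx
import onnx.utils

# Cut the model just after the suspect conv/relu/maxpool block and compare
# CPU vs TensorRT outputs on the truncated model.
onnx.utils.extract_model(
    "full_model.onnx",       # hypothetical source model
    "truncated_model.onnx",  # sliced model to compare EPs on
    input_names=["input"],         # hypothetical graph input name
    output_names=["maxpool_out"],  # hypothetical tensor just after the maxpool
)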

BengtGustafsson (Contributor) commented:

Here is a complete test kit with a Python program and an ONNX file. As demonstrated, when running on a 128x128 image there is no diff between CPU and TRT, but with 256x256 there is a difference that kills our unit tests.

ConvReluMaxpoolIssue.zip

BengtGustafsson commented Sep 23, 2024

We no longer think it is an accidental float16 issue; maybe the optimization moves to another algorithm with larger images, for instance an FFT-based implementation that may be too inexact.

jingyanwangms commented Sep 30, 2024

@c1aude @BengtGustafsson
We are having trouble reproducing your error on our end. We tried running the provided script with the latest main (built from source) and do not see a difference between CPU and TensorRT output. I think @Chi tried 1.19 or 1.18 for YOLOv7 and also cannot repro.
Here's what we suggest:

1. Can you try running the model with trtexec and see if it gives the same error?
   trtexec is located under the TensorRT installation path, tensorrt/bin/trtexec:
   trtexec --onnx=<model-path> --verbose --dumpOutput --iterations=1 --loadInputs='input'
   Here's how you can save the input tensor:
   data = np.asarray(input_tensor, dtype=np.float32)
   data.tofile("input")
2. Some of the optimizations might be hardware specific. That could be a factor. What are the exact onnxruntime-gpu, TensorRT, and hardware versions you're using?

jingyanwangms (Contributor) commented:

@BengtGustafsson In our testing, we can see variance on A100 but not on V100, so it's architecture dependent. What GPU architecture are you using?
We asked the NVIDIA people in our sync meeting. They suggested 1e-5 ~ 1e-3 as an accuracy tolerance, though I cannot find a reference to this in their official documentation. In our test we saw a deviation of 0.00013554096, which falls within this range.

jingyanwangms (Contributor) commented:

@BengtGustafsson Can you give export NVIDIA_TF32_OVERRIDE=0 a try? This disables the TF32 Tensor Core optimization, which is on by default on A100. My testing (ORT 1.18.2, TRT 10.4) shows there's no more deviation after disabling it.
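
For a Python repro, the variable can also be set in-process, as long as it happens before the CUDA/TensorRT libraries initialize; a minimal sketch:

import os

# Must be set before onnxruntime (and thus CUDA/TensorRT) is loaded.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import onnxruntime as ort  # imported after the override so it takes effect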

jingyanwangms commented Oct 7, 2024

@c1aude
I still get the same error,
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
from if classes_scores >= self.confidence_thresh:. The output dimension (1, 64512, 89) is not what the post-processing script expects.
I used the ONNX model that was shared earlier (best_562_0712_e2e.onnx) instead of the one called for in the script (best_562_0712.onnx).
Anyway, if I directly compare the model output tensor values of CPU and TensorRT, they look close. The difference is around 1e-2, which is on the large side, but for a classification problem this should not cause a very different result.
CPU:
[[2.8125420e+00 7.8375378e+00 6.8491874e+00 ... 1.4686584e-04
7.7545643e-05 8.5532665e-05]
[1.1132017e+01 7.9093590e+00 1.7115019e+01 ... 1.3202429e-04
7.6740980e-05 8.2492828e-05]
[1.9796419e+01 8.6261234e+00 3.1406281e+01 ... 1.1369586e-04
8.1747770e-05 8.7499619e-05]
...
[9.4306152e+02 9.9806207e+02 2.1758392e+02 ... 9.0873241e-04
1.4138222e-03 1.3399422e-03]
[9.7226398e+02 9.9804059e+02 1.9094728e+02 ... 9.4437599e-04
1.6762912e-03 1.3577640e-03]
[1.0059143e+03 9.9801050e+02 2.3049202e+02 ... 1.0394752e-03
1.3738573e-03 1.0043979e-03]]

TensorRT:
[[[2.80859375e+00 7.83593750e+00 6.83984375e+00 ... 1.45673752e-04
7.78436661e-05 8.57114792e-05]
[1.11250000e+01 7.89843750e+00 1.71250000e+01 ... 1.31487846e-04
7.65919685e-05 8.23736191e-05]
[1.97812500e+01 8.63281250e+00 3.14218750e+01 ... 1.13368034e-04
8.23736191e-05 8.77380371e-05]
...
[9.43000000e+02 9.98000000e+02 2.17625000e+02 ... 9.07421112e-04
1.41048431e-03 1.33514404e-03]
[9.72500000e+02 9.98000000e+02 1.91250000e+02 ... 9.43660736e-04
1.66797638e-03 1.35135651e-03]
[1.00600000e+03 9.98000000e+02 2.30750000e+02 ... 1.03664398e-03
1.36756897e-03 1.00040436e-03]]]

If I use FP32 for the TensorRT EP, the output becomes much closer, ~1e-3:
[[[2.81318426e+00 7.83807373e+00 6.84823132e+00 ... 1.46798266e-04
7.75728695e-05 8.55009494e-05]
[1.11298866e+01 7.90688133e+00 1.71119938e+01 ... 1.31941168e-04
7.67640740e-05 8.24807939e-05]
[1.97889137e+01 8.62717152e+00 3.13998184e+01 ... 1.13563103e-04
8.17740729e-05 8.75289552e-05]
...
[9.43060730e+02 9.98066040e+02 2.17566589e+02 ... 9.08780610e-04
1.41405826e-03 1.34034338e-03]
[9.72270386e+02 9.98048523e+02 1.90975510e+02 ... 9.44551546e-04
1.67512242e-03 1.35688041e-03]
[1.00591754e+03 9.98025635e+02 2.30555847e+02 ... 1.03968882e-03
1.37228891e-03 1.00394525e-03]]]

With export NVIDIA_TF32_OVERRIDE=0, the output is very close, ~1e-6:
[[[2.8125434e+00 7.8375340e+00 6.8491859e+00 ... 1.4692696e-04
7.7577155e-05 8.5513588e-05]
[1.1132016e+01 7.9093599e+00 1.7115021e+01 ... 1.3195853e-04
7.6756674e-05 8.2475213e-05]
[1.9796421e+01 8.6261225e+00 3.1406290e+01 ... 1.1360870e-04
8.1757702e-05 8.7516019e-05]
...
[9.4306152e+02 9.9806207e+02 2.1758388e+02 ... 9.0868835e-04
1.4138712e-03 1.3399873e-03]
[9.7226398e+02 9.9804059e+02 1.9094728e+02 ... 9.4442023e-04
1.6762679e-03 1.3577911e-03]
[1.0059143e+03 9.9801050e+02 2.3049196e+02 ... 1.0394878e-03
1.3738739e-03 1.0044393e-03]]]
I'll check with the team on our suggestions for the FP32 vs FP16 output discrepancy. Meanwhile, here is NVIDIA's suggestion on accuracy.
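
For reference, a minimal sketch of this kind of direct CPU-vs-TensorRT output comparison (model filename from the earlier comments; the input shape is a placeholder, the real tensor comes from the repro script's preprocess()):

import numpy as np
import onnxruntime as ort

def run_with(ep, model_path, feed):
    sess = ort.InferenceSession(model_path, providers=[ep])
    return sess.run(None, feed)[0]

# Placeholder input; substitute the preprocessed image tensor and real shape.
img_data = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
feed = {"images": img_data}

cpu_out = run_with("CPUExecutionProvider", "best_562_0712_e2e.onnx", feed)
trt_out = run_with("TensorrtExecutionProvider", "best_562_0712_e2e.onnx", feed)

# Squeeze away any extra batch dimension before comparing element-wise.
diff = np.abs(np.squeeze(cpu_out) - np.squeeze(trt_out))
print("max abs diff:", diff.max())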

c1aude commented Oct 8, 2024

@jingyanwangms
Unfortunately, I realized that the ONNX file I provided was incorrect.

For testing purposes, here is the ONNX file re-converted from the .pt file:
https://drive.google.com/file/d/1Y2Nyjjii0zTLcPFTn2bLdMXVyNvqvkA9/view?usp=sharing

For the newly converted ONNX file, trtexec passes, but the dumped output is displayed as below:
"name" : "output"
"dimensions" : "-1x7"
"values" : [ ]

The output from the CPU and TensorRT is shown below.

CPU:
[0.0000000e+00, 4.4677078e+02, 7.1414655e+02, 4.8075333e+02, 7.7660529e+02, 5.4000000e+01, 9.3979961e-01],
[0.0000000e+00, 6.2021057e+02, 4.8042380e+02, 6.5300745e+02, 5.4063220e+02, 5.4000000e+01, 9.3742400e-01],
[0.0000000e+00, 5.2568127e+02, 7.1490686e+02, 5.5946606e+02, 7.7716553e+02, 2.0000000e+00, 9.3704671e-01],
[0.0000000e+00, 4.8630933e+02, 7.1455713e+02, 5.2070258e+02, 7.7725500e+02, 2.0000000e+00, 9.3697560e-01],
[0.0000000e+00, 5.8296631e+02, 5.5369794e+02, 6.1430652e+02, 6.1389105e+02, 3.8000000e+01, 9.3695939e-01],
[0.0000000e+00, 5.0879837e+02, 4.8047485e+02, 5.3939441e+02, 5.4027264e+02, 6.0000000e+00, 9.3668890e-01],
[0.0000000e+00, 3.9539548e+02, 4.7931345e+02, 4.2837576e+02, 5.4133594e+02, 4.6000000e+01, 9.3580389e-01],
[0.0000000e+00, 4.7167825e+02, 4.0781860e+02, 5.0296396e+02, 4.6727991e+02, 3.8000000e+01, 9.3566370e-01],
[0.0000000e+00, 4.3337311e+02, 5.5317072e+02, 4.6651141e+02, 6.1451154e+02, 4.3000000e+01, 9.3515205e-01],
[0.0000000e+00, 6.2102240e+02, 4.0778052e+02, 6.5267279e+02, 4.6844904e+02, 6.0000000e+00, 9.3481648e-01],
[0.0000000e+00, 3.6846997e+02, 7.1337030e+02, 4.0029083e+02, 7.7618292e+02, 4.2000000e+01, 9.3374938e-01],
[0.0000000e+00, 4.0746851e+02, 7.1414252e+02, 4.4143250e+02, 7.7686285e+02, 3.0000000e+00, 9.3292224e-01],
[0.0000000e+00, 5.8277606e+02, 4.8077115e+02, 6.1462140e+02, 5.4043500e+02, 3.0000000e+00, 9.3240434e-01],
[0.0000000e+00, 4.3421063e+02, 4.0774142e+02, 4.6557245e+02, 4.6819547e+02, 4.0000000e+01, 9.3209356e-01],
[0.0000000e+00, 3.9403558e+02, 5.5388678e+02, 4.2971735e+02, 6.1416327e+02, 4.0000000e+00, 9.3174213e-01],
[0.0000000e+00, 5.4469495e+02, 5.5259711e+02, 5.7639978e+02, 6.1391412e+02, 4.1000000e+01, 9.3040138e-01],
[0.0000000e+00, 6.0542896e+02, 7.1455646e+02, 6.3955334e+02, 7.7711102e+02, 9.0000000e+00, 9.3038446e-01],
[0.0000000e+00, 5.6628546e+02, 7.1490137e+02, 5.9957684e+02, 7.7637537e+02, 4.0000000e+00, 9.3033844e-01],
[0.0000000e+00, 4.7034698e+02, 4.7952795e+02, 5.0211658e+02, 5.4004779e+02, 4.1000000e+01, 9.2930132e-01],
[0.0000000e+00, 3.9689526e+02, 4.0735898e+02, 4.3005579e+02, 4.6771957e+02, 5.4000000e+01, 9.2828971e-01],
[0.0000000e+00, 4.3302881e+02, 4.8034964e+02, 4.6414935e+02, 5.4010651e+02, 4.0000000e+00, 9.2823547e-01],
[0.0000000e+00, 1.0720940e+02, 7.3440167e+02, 1.5167479e+02, 7.7617743e+02, 6.6000000e+01, 9.2797732e-01],
[0.0000000e+00, 5.8221155e+02, 4.0838736e+02, 6.1571814e+02, 4.6814658e+02, 4.0000000e+00, 9.2790771e-01],
[0.0000000e+00, 5.0858850e+02, 5.5369598e+02, 5.3975403e+02, 6.1421844e+02, 4.2000000e+01, 9.2737585e-01],
[0.0000000e+00, 5.4399890e+02, 4.8034518e+02, 5.7676672e+02, 5.4144305e+02, 4.0000000e+01, 9.2568499e-01],
[0.0000000e+00, 5.4611096e+02, 4.0779602e+02, 5.7671118e+02, 4.6802509e+02, 2.0000000e+00, 9.2544478e-01],
[0.0000000e+00, 6.1977032e+02, 5.5363629e+02, 6.5135797e+02, 6.1422748e+02, 4.7000000e+01, 9.2314225e-01],
[0.0000000e+00, 4.7027097e+02, 5.5330487e+02, 5.0327151e+02, 6.1475018e+02, 4.8000000e+01, 8.5506588e-01],
[0.0000000e+00, 6.4625714e+02, 7.1455139e+02, 6.7965399e+02, 7.7815332e+02, 4.8000000e+01, 6.0979420e-01],
[0.0000000e+00, 6.4640173e+02, 7.1473114e+02, 6.7941016e+02, 7.7817133e+02, 4.9000000e+01, 6.0894704e-01]

TensorRT:
[0.00000000e+00, 4.46775421e+02, 7.14117310e+02, 4.80748688e+02, 7.76643433e+02, 3.00000000e+00, 9.40011501e-01],
[0.00000000e+00, 6.20211182e+02, 4.80428650e+02, 6.53026367e+02, 5.40624573e+02, 3.00000000e+00, 9.37641263e-01],
[0.00000000e+00, 4.86311005e+02, 7.14555237e+02, 5.20709229e+02, 7.77251892e+02, 5.40000000e+01, 9.37415719e-01],
[0.00000000e+00, 5.82974609e+02, 5.53695374e+02, 6.14301514e+02, 6.13891296e+02, 4.10000000e+01, 9.37294006e-01],
[0.00000000e+00, 5.25686768e+02, 7.14919922e+02, 5.59472656e+02, 7.77147827e+02, 2.00000000e+00, 9.37086999e-01],
[0.00000000e+00, 5.08803497e+02, 4.80459625e+02, 5.39392639e+02, 5.40299194e+02, 6.00000000e+00, 9.36555922e-01],
[0.00000000e+00, 3.95394165e+02, 4.79318604e+02, 4.28377075e+02, 5.41348145e+02, 4.60000000e+01, 9.35925961e-01],
[0.00000000e+00, 4.33373566e+02, 5.53178467e+02, 4.66516266e+02, 6.14500244e+02, 4.00000000e+00, 9.35737729e-01],
[0.00000000e+00, 4.71697876e+02, 4.07830658e+02, 5.02951355e+02, 4.67266815e+02, 4.00000000e+01, 9.35402811e-01],
[0.00000000e+00, 6.21050354e+02, 4.07783356e+02, 6.52663635e+02, 4.68450958e+02, 4.00000000e+00, 9.34674740e-01],
[0.00000000e+00, 3.68465546e+02, 7.13373901e+02, 4.00284271e+02, 7.76199097e+02, 1.00000000e+00, 9.33819771e-01],
[0.00000000e+00, 4.07472931e+02, 7.14122070e+02, 4.41429169e+02, 7.76880737e+02, 3.00000000e+00, 9.32938814e-01],
[0.00000000e+00, 5.82778625e+02, 4.80781708e+02, 6.14613831e+02, 5.40427490e+02, 4.00000000e+01, 9.32536781e-01],
[0.00000000e+00, 4.34210815e+02, 4.07748230e+02, 4.65586731e+02, 4.68187866e+02, 5.40000000e+01, 9.32058394e-01],
[0.00000000e+00, 3.94049225e+02, 5.53911499e+02, 4.29710541e+02, 6.14152954e+02, 4.00000000e+00, 9.31950986e-01],
[0.00000000e+00, 5.66284729e+02, 7.14886658e+02, 5.99587585e+02, 7.76388977e+02, 2.00000000e+00, 9.30655420e-01],
[0.00000000e+00, 5.44713501e+02, 5.52607666e+02, 5.76400635e+02, 6.13912964e+02, 4.20000000e+01, 9.30561185e-01],
[0.00000000e+00, 6.05432556e+02, 7.14551147e+02, 6.39542419e+02, 7.77110474e+02, 4.00000000e+00, 9.30454373e-01],
[0.00000000e+00, 4.70345245e+02, 4.79519470e+02, 5.02114594e+02, 5.40056763e+02, 4.00000000e+00, 9.29276645e-01],
[0.00000000e+00, 3.96894196e+02, 4.07368164e+02, 4.30053741e+02, 4.67710266e+02, 5.40000000e+01, 9.28609908e-01],
[0.00000000e+00, 4.33044434e+02, 4.80328766e+02, 4.64143250e+02, 5.40119873e+02, 4.60000000e+01, 9.28506732e-01],
[0.00000000e+00, 1.07209076e+02, 7.34401245e+02, 1.51692139e+02, 7.76182495e+02, 1.00000000e+00, 9.28080320e-01],
[0.00000000e+00, 5.82211060e+02, 4.08378174e+02, 6.15725342e+02, 4.68153076e+02, 2.00000000e+00, 9.27607834e-01],
[0.00000000e+00, 5.08416809e+02, 5.54022888e+02, 5.39827332e+02, 6.14243591e+02, 4.80000000e+01, 9.27313626e-01],
[0.00000000e+00, 5.43996338e+02, 4.80356964e+02, 5.76752930e+02, 5.41449280e+02, 6.00000000e+00, 9.25769091e-01],
[0.00000000e+00, 5.46109802e+02, 4.07804352e+02, 5.76706970e+02, 4.68016510e+02, 1.00000000e+00, 9.24954176e-01],
[0.00000000e+00, 6.19758728e+02, 5.53634888e+02, 6.51377380e+02, 6.14229736e+02, 3.80000000e+01, 9.23013985e-01],
[0.00000000e+00, 4.70281067e+02, 5.53326294e+02, 5.03261047e+02, 6.14732300e+02, 4.30000000e+01, 8.55410874e-01],
[0.00000000e+00, 6.46267456e+02, 7.14581238e+02, 6.79637939e+02, 7.78107483e+02, 4.90000000e+01, 6.09435678e-01],
[0.00000000e+00, 6.46399841e+02, 7.14738525e+02, 6.79401794e+02, 7.78164307e+02, 9.00000000e+00, 6.08466506e-01]

When I diff'd and analyzed the results, it seems that the values in the 6th column, which determines the class, are all strange in the TensorRT output.

The PC I used has an RTX 4070, and I've tried building and testing with TensorRT 10.0.1.6 and TensorRT 10.2.0.19; both have the same issue.

c1aude closed this as completed Oct 8, 2024
c1aude reopened this Oct 8, 2024
jingyanwangms (Contributor) commented:

@c1aude Thank you for providing the ONNX graph and pointing out where the differing values are. I can repro the issue with onnxruntime + TensorRT 10.4 now. But in trtexec I see the same output as the onnxruntime TensorRT EP. Can you please clarify this?
Here's how I run trtexec:
trtexec --onnx=best_562_0712.onnx --verbose --dumpOutput --iterations=1 --loadInputs='images'
The input file images is saved with:

data = np.asarray(img_data, dtype=np.float32)
data.tofile("images")

after img_data = self.preprocess().

BengtGustafsson (Contributor) commented:

Thanks! Our differences disappear after setting NVIDIA_TF32_OVERRIDE=0.

I tested this on my A5000; we'll see what happens on the various GPUs in our test park.

I could not detect a speed penalty from disabling this, so we'll just set it.

Now I just wonder if there is a way to set this mode without using an environment variable. Even if we can do it in our program, it isn't a good way of working in general. Note that we work entirely in C++.

jingyanwangms (Contributor) commented:

> Thanks! Our differences disappear after setting NVIDIA_TF32_OVERRIDE=0.
>
> I tested this on my A5000; we'll see what happens on the various GPUs in our test park.
>
> I could not detect a speed penalty from disabling this, so we'll just set it.
>
> Now I just wonder if there is a way to set this mode without using an environment variable. Even if we can do it in our program, it isn't a good way of working in general. Note that we work entirely in C++.

As far as we know, there's no way to set this other than through the environment variable. This is an NVIDIA setting, so we don't control it. We'll ask NVIDIA in our sync meeting.

BengtGustafsson (Contributor) commented:

You can do it programmatically at the TensorRT level, but not through onnxruntime, it seems:

config->clearFlag(BuilderFlag::kTF32);

https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#tf32-inference-c

jingyanwangms (Contributor) commented:

We do not expose this option. In general you want to use the environment variable, because it works for both TensorRT and CUDA: a model can fall back to the CUDA EP.

jingyanwangms (Contributor) commented:

@samsonyilma it's fixed in TensorRT 10.6 now. I verified your script. Can you try with TensorRT 10.6?

samsonyilma commented:

> @samsonyilma it's fixed in TensorRT 10.6 now. I verified your script. Can you try with TensorRT 10.6?

I tried running the script and got a 'Target GPU SM 70 is not supported by this TensorRT release' error.
I am using a Tesla V100-SXM2-16GB on Ubuntu 22.04.

2024-11-11 20:32:36.022673448 [E:onnxruntime:Default, tensorrt_execution_provider.h:88 log] [2024-11-11 20:32:36   ERROR] IBuilder::buildSerializedNetwork: Error Code 9: API Usage Error (Target GPU SM 70 is not supported by this TensorRT release.)
2024-11-11 20:32:36.022751127 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running TRTKernel_graph_main_graph_8383588888691045731_0 node. Name:'TensorrtExecutionProvider_TRTKernel_graph_main_graph_8383588888691045731_0_0' Status Message: TensorRT EP failed to create engine from network.

jywu-msft (Member) commented:

> @samsonyilma it's fixed in TensorRT 10.6 now. I verified your script. Can you try with TensorRT 10.6?

> I tried running the script and got a 'Target GPU SM 70 is not supported by this TensorRT release' error. I am using a Tesla V100-SXM2-16GB on Ubuntu 22.04.
>
> 2024-11-11 20:32:36.022673448 [E:onnxruntime:Default, tensorrt_execution_provider.h:88 log] [2024-11-11 20:32:36   ERROR] IBuilder::buildSerializedNetwork: Error Code 9: API Usage Error (Target GPU SM 70 is not supported by this TensorRT release.)
> 2024-11-11 20:32:36.022751127 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running TRTKernel_graph_main_graph_8383588888691045731_0 node. Name:'TensorrtExecutionProvider_TRTKernel_graph_main_graph_8383588888691045731_0_0' Status Message: TensorRT EP failed to create engine from network.

The Volta architecture stopped being supported by TensorRT as of 10.5 :(

samsonyilma commented:

Huh - TensorRT dropping support for Volta is a bummer...

jywu-msft (Member) commented:

> Huh - TensorRT dropping support for Volta is a bummer...

Yeah, unfortunate.
Here's the announcement for your reference: https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html#rel-10-5-0
