TensorRT EP's inference results are abnormal. #21457
Comments
Hi @c1aude, could you also share the model, the requirements.txt of your Python env, and the test image so we can repro this issue?
Hello @yf711, here are the env and model files used for inference. env requirements.txt: test model:
This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.
Adding a note to further debug this issue.
Hi @c1aude, thank you for the detailed repro information. Unfortunately, when I try to run the script, I get the error below. Can you please verify the script and share a version that works?
Is there any progress on this one? We get a similar problem with some of our networks but not all. I'm not sure whether it worked on onnxruntime 1.17 and TRT 8.6.1.6, which we used before, but I think it did. Now with 1.19.2 and TRT 10.4.0.26 it definitely differs from the other providers.
Can you provide some repro test cases for @jingyanwangms to investigate further? Thanks!
I log errors from the onnx log and it sometimes writes:
`onnxruntime: [2024-09-17 16:10:25 ERROR] Error Code: 9: Skipping tactic 0xaa15e43058248292 due to exception canImplement1 [tensorrt_execution_provider.h:88 onnxruntime::TensorrtLogger::log]`
To me this does not sound like an error, more like a warning. I have no idea what the hex code is, but I get the same warning with many different codes.
Yes, this is a warning. It indicates TensorRT is skipping a specific tactic (optimization approach) due to an internal issue in the implementation (canImplement1). This should not block running your model on TensorRT, since other tactics should kick in.
I too observe similar behavior: different results with CUDAExecutionProvider vs TensorrtExecutionProvider. I get the same outputs (CUDA vs TensorRT providers) when using onnxruntime-gpu==1.17.1 with tensorrt==8.6.0. I can provide the script I used in this issue... The script I am using:
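A minimal sketch of the kind of provider comparison being discussed (this is not the script referenced above; the model path and input shape are placeholders):

```python
import numpy as np
import onnxruntime as ort

MODEL_PATH = "model.onnx"  # placeholder: substitute the actual model
x = np.random.rand(1, 3, 640, 640).astype(np.float32)  # placeholder input shape

outputs = {}
for provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
    # Fall back to CPU for any nodes the requested provider cannot run.
    sess = ort.InferenceSession(MODEL_PATH, providers=[provider, "CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name
    outputs[provider] = sess.run(None, {input_name: x})[0]

# Element-wise comparison of the two providers' outputs.
diff = np.abs(outputs["CUDAExecutionProvider"] - outputs["TensorrtExecutionProvider"])
print("max abs diff:", diff.max())
```

The differences quoted later in the thread (on the order of 1e-3) refer to this kind of element-wise comparison between providers.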
That seems like a clear-cut case. Our case involves proprietary models and C++ code that I can't share. We don't use the Python bindings, so unfortunately it would be a big job to create a repro. In our case we get small differences all over the output image, on the order of 1e-3. The errors are in the range of what would be expected if the model were running in 16-bit mode. As I understand it, TRT 10 now sets precision from the ONNX data. Maybe it has the defaults wrong if no precision is set? Our coefficient tensors are 32-bit floats, though.
I have tried setting it, but the results I get are not close at all. I run into the same issue with another model architecture as well (which I can't share). Output of the above script:
with onnxruntime-gpu=1.19.2 with tensorrt=10.4.0, the outputs are completely different
@samsonyilma Thank you for the simple repro. Yes, I can see different results between CUDA and TensorRT with onnxruntime-gpu=1.19.2 and tensorrt=10.4.0. We're investigating on our side.
For us it seems to only happen on Windows; we get correct results on Linux. We will however continue checking that this info is correct, we could have messed up the version increase on Linux or something like that.
I tried running it again with that code and it worked fine. It looks like the same issue has been reproduced by another commenter, but if you need to test my code, could you please make the following changes and test it? On lines 64 and 77, change self.confidence_thres to 0.9 and self.iou_thresh to 0.5.
Thanks, we will look at your repro case too.
I mentioned above that it works on Linux. But this may not be a Linux/Windows issue; it could also be that we don't include the CUDA provider in the Linux build but we do in the Windows build. May I ask c1aude and samsonyilma which OS you're on and which providers you have included in your onnxruntime builds or downloaded packages?
I am using Ubuntu 22.04, Python API with onnxruntime-gpu & tensorrt packages.
I'm using Windows 11.
A colleague of mine removed the last layers until the error disappeared and then added back the tentative culprit layer. This was a maxpool layer, but we guess that there is some optimization involving the preceding layers. The sequence is conv/relu/maxpool, and our guess is that for some reason the input to maxpool is truncated to 16-bit float although our network is 32-bit float in all its parts. This is at least consistent with the magnitude of the errors.
Here is a complete test kit with a Python program and an onnx file. As demonstrated, when running on a 128x128 image there is no diff between CPU and TRT, but with 256x256 there is a difference that kills our unit tests.
We no longer think it is an accidental-float16 issue; maybe the optimization moves to another algorithm with larger images, for instance an FFT-based implementation that may be too inexact.
@c1aude @BengtGustafsson
@BengtGustafsson in our testing, we can see variance on A100 but not on V100. So it's architecture dependent. What GPU architecture are you using?
@BengtGustafsson can you give
@c1aude TensorRT: If I use fp32 for the TensorRT EP, the output becomes much closer, ~1e-3. With
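For reference (this is an illustration, not part of the comment above): in the Python API, fp16 engine building for the TensorRT EP is controlled by the documented `trt_fp16_enable` provider option, which defaults to fp32. A minimal sketch with a placeholder model path:

```python
import onnxruntime as ort

# Keep the TensorRT engine in fp32 by leaving fp16 disabled (the default).
trt_options = {"trt_fp16_enable": False}

sess = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=[
        ("TensorrtExecutionProvider", trt_options),
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    ],
)
```

Note that this flag does not affect TF32 behavior, which later comments address separately via NVIDIA_TF32_OVERRIDE.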
@jingyanwangms For testing purposes, here is the onnx file that we converted back to a .pt file. For the newly converted onnx file, trtexec passes, but for the dumped output, the result is displayed as below. The output from the CPU and TensorRT is shown below. CPU: TensorRT: When I diff'd and analyzed the results, it seems that the values in the 6th row, which determine the class in TensorRT, are all strange. The PC I used has an RTX 4070, and I've tried building and testing with TensorRT-10.0.1.6 and TensorRT-10.2.0.19 and have the same issue.
@c1aude Thank you for providing the onnx graph and pointing out where the different value is. I can repro the issue with onnxruntime + TensorRT 10.4 now. But I see the same output as the onnxruntime TensorRT EP in trtexec. Can you please clarify this?
```python
data = np.asarray(img_data, dtype=np.float32)
data.tofile("images")
```
after `img_data = self.preprocess()`
Thanks! Our differences disappear after setting NVIDIA_TF32_OVERRIDE=0. I tested this on my A5000; we'll see what happens on the various GPUs in our test park. I could not detect a speed penalty for disabling this, so we'll just set it. Now I just wonder if there is a way to set this mode without using an environment variable. Even if we can do it in our program, it isn't a good way of working in general. Note that we work entirely in C++.
As far as we know, there's no way other than setting this environment variable. This is an NVIDIA setting, so we don't control it. We'll ask NVIDIA in our sync meeting.
You can do it programmatically at the TRT level, but not through onnxruntime it seems: `config->clearFlag(BuilderFlag::kTF32);` https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#tf32-inference-c
We do not expose this option. In general you want to use the environment variable because it works for both TensorRT and CUDA, since a model can fall back to the CUDA EP.
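For Python users following along, a minimal sketch of applying the override before the session is created (the model path is a placeholder):

```python
import os

# NVIDIA_TF32_OVERRIDE is read when the CUDA/TensorRT libraries initialize,
# so set it before creating the session (or, more robustly, before launching
# the process at all).
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import onnxruntime as ort  # imported after the override is set

sess = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
)
```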
@samsonyilma it's fixed in TensorRT 10.6 now. I verified your script. Can you try with TensorRT 10.6?
I tried running the script and got a `Target GPU SM 70 is not supported by this TensorRT release` error.
The Volta architecture is no longer supported by TensorRT as of 10.5 :(
Huh - TensorRT dropping support for Volta is a bummer...
Yeah, unfortunate.
Describe the issue
Inference results are abnormal when using YOLOv7 models with the TensorRT EP.
We have confirmed that the results are normal when using CPU and CUDA.
The issue was reproducible in versions 1.18.0 to 1.18.1 using TensorRT 10, and did not occur in versions 1.17.3 and earlier using TensorRT 8.6.1.6.
When using TensorRT 10, are there any additional steps required when converting PyTorch models to ONNX, compared to TensorRT 8?
TensorRT result:
CPU or CUDA result:
To reproduce
The code we used for testing is shown below.
Urgency
No response
Platform
Windows
OS Version
Windows 11
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.18.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
TensorRT
Execution Provider Library Version
CUDA 11.8, Cudnn 8.9.7, TensorRT 10.2.0.19