Is it possible to use a TRT engine file with onnxruntime? #17811

Closed
MinGiSa opened this issue Oct 6, 2023 · 6 comments
Labels
ep:CUDA (issues related to the CUDA execution provider)
ep:TensorRT (issues related to the TensorRT execution provider)
platform:windows (issues related to the Windows platform)
stale (issues that have not been addressed in a while; categorized by a bot)

Comments

@MinGiSa

MinGiSa commented Oct 6, 2023

Describe the issue

I have converted an existing ONNX file to a TensorRT file (engine). I would like to perform inference using the converted TensorRT file with ONNX Runtime, but I am unsure of the process. When I attempt inference using the following code:

ortSession = ort.InferenceSession(engine, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider'])

I encounter the following error:

Traceback (most recent call last):
  File "packages\onnxruntime\capi\onnxruntime_inference_collection.py", line 405, in __init__
    raise TypeError(f"Unable to load from type '{type(path_or_bytes)}'")
TypeError: Unable to load from type '<class 'tensorrt.tensorrt.ICudaEngine'>'

It seems there is a problem loading the serialized TensorRT engine file into ONNX Runtime. How can I resolve this?

To reproduce

import tensorrt as trt
import os
import cv2
import time
import numpy as np
from cuda import cuda
import warnings
import onnxruntime as ort

warnings.filterwarnings("ignore")

def preprocessImage(imagePath, imageSize):
    img = cv2.imread(imagePath, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    image = cv2.dnn.blobFromImage(img, 1 / 255, imageSize, (0, 0, 0))

    image[0][0] = (image[0][0] - 0.485) / 0.229
    image[0][1] = (image[0][1] - 0.456) / 0.224
    image[0][2] = (image[0][2] - 0.406) / 0.225
    return image

def trtInference(engine, context, data):
    nInput = np.sum([engine.binding_is_input(i) for i in range(engine.num_bindings)])
    nOutput = engine.num_bindings - nInput
    # print('nInput:', nInput)
    # print('nOutput:', nOutput)

    # for i in range(nInput):
    #     print("Bind[%2d]:i[%2d]->" % (i, i), engine.get_binding_dtype(i), engine.get_binding_shape(i), context.get_binding_shape(i), engine.get_binding_name(i))
    # for i in range(nInput, nInput + nOutput):
    #     print("Bind[%2d]:o[%2d]->" % (i, i - nInput), engine.get_binding_dtype(i), engine.get_binding_shape(i), context.get_binding_shape(i), engine.get_binding_name(i))

    bufferH = []
    bufferH.append(np.ascontiguousarray(data.reshape(-1)))

    for i in range(nInput, nInput + nOutput):
        bufferH.append(np.empty(context.get_binding_shape(i), dtype=trt.nptype(engine.get_binding_dtype(i))))

    bufferD = []
    for i in range(nInput + nOutput):
        bufferD.append(cuda.cuMemAlloc(bufferH[i].nbytes)[1])

    for i in range(nInput):
        cuda.cuMemcpyHtoD(bufferD[i], bufferH[i].ctypes.data, bufferH[i].nbytes)

    context.execute_v2(bufferD)

    for i in range(nInput, nInput + nOutput):
        cuda.cuMemcpyDtoH(bufferH[i].ctypes.data, bufferD[i], bufferH[i].nbytes)

    for b in bufferD:
        cuda.cuMemFree(b)

    return bufferH

os.environ['CUDA_MODULE_LOADING'] = 'LAZY'

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

trt_engine_path = r"trt engine path here"
imagePath = r'image path here'
imageSize = (580, 410)
batchSize = 1

processedImage = preprocessImage(imagePath, imageSize)

with open(trt_engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()

trtStart = time.time()
trtOutputs = trtInference(engine, context, processedImage)
trtOutputs = np.array(trtOutputs[1]).reshape(batchSize, -1)
trtEnd = time.time()

print('--tensorrt--')
print(trtOutputs.shape)
print(trtOutputs[0][:10])
print(np.argmax(trtOutputs, axis=1))
print('Time: ', trtEnd - trtStart)

def trtInferenceWithONNXRuntime(ort_session, data):
    ort_inputs = {ort_session.get_inputs()[0].name: data}
    ort_outputs = ort_session.run(None, ort_inputs)
    return ort_outputs

ortSession = ort.InferenceSession(engine, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider'])

ortStart = time.time()
ortOutputs = trtInferenceWithONNXRuntime(ortSession, processedImage)
ortOutputs = np.array(ortOutputs[0]).reshape(batchSize, -1)
ortEnd = time.time()

print('--onnxruntime--')
print(ortOutputs.shape)
print(ortOutputs[0][:10])
print(np.argmax(ortOutputs, axis=1))
print('Time: ', ortEnd - ortStart)

Urgency

No response

Platform

Windows

OS Version

windows 11

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA, TensorRT

Execution Provider Library Version

torch 2.0.1, CUDA 11.7

github-actions bot added the ep:CUDA, ep:TensorRT, and platform:windows labels on Oct 6, 2023
@wschin
Contributor

wschin commented Oct 6, 2023

The only supported model format in ORT is ONNX. If you want to run a TensorRT model, you need to either (1) convert it back to ONNX or (2) use the existing ONNX model directly. If the goal is to compare performance, using the existing ONNX model should be enough. Since you specified ['TensorrtExecutionProvider', 'CUDAExecutionProvider'], ORT will do its best to find the fastest configuration to execute the graph with both TensorRT and native CUDA, so you don't need to worry about performance. On the other hand, if ORT doesn't perform well, that's something we want to fix. Thank you.
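For reference, a minimal sketch of that second option (the model path, input name, and input shape below are placeholders for illustration, not taken from this issue):

import numpy as np
import onnxruntime as ort

# Create the session from the original ONNX file; the TensorRT EP builds its
# engines internally and falls back to the CUDA EP for unsupported nodes.
session = ort.InferenceSession(
    "model.onnx",  # hypothetical path to the existing ONNX model
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider"],
)

# Run with a dummy NCHW input; the shape here is assumed, not known from the issue.
inputName = session.get_inputs()[0].name
dummy = np.random.rand(1, 3, 410, 580).astype(np.float32)
outputs = session.run(None, {inputName: dummy})
print(outputs[0].shape)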

@tianleiwu
Contributor

@MinGiSa,

You can follow the stable diffusion example to run the TensorRT EP. The key part is constructing the provider options from an input profile to support dynamic shapes:

https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py

The other part is just like any other provider: create a session from the ONNX file, then run the model with I/O binding and CUDA graph.
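A rough sketch of that pattern (the input name, shape ranges, and cache path are made up for illustration, and the explicit trt_profile_* options assume an ORT build recent enough to expose them):

import onnxruntime as ort

# TensorRT EP options: enable FP16 and engine caching, and declare a dynamic-shape
# profile for a hypothetical input named "input" with batch sizes 1 to 8.
trtOptions = {
    "trt_fp16_enable": True,
    "trt_engine_cache_enable": True,
    "trt_engine_cache_path": "./trt_cache",
    "trt_profile_min_shapes": "input:1x3x410x580",
    "trt_profile_opt_shapes": "input:4x3x410x580",
    "trt_profile_max_shapes": "input:8x3x410x580",
}

session = ort.InferenceSession(
    "model.onnx",  # placeholder path to the ONNX model
    providers=[("TensorrtExecutionProvider", trtOptions), "CUDAExecutionProvider"],
)

After that, running the model is the same as with any other provider, optionally using I/O binding to keep inputs and outputs on the GPU.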

@wschin
Contributor

wschin commented Oct 6, 2023

@tianleiwu, do you have a simpler example? The example mentioned is hard to read because

  1. its model is complicated,
  2. the coding style is not functional, so the user needs to go through the class hierarchy to understand the overall behavior, and
  3. it's not a standalone script the user can just execute and observe.

If we don't have such an example, we need to create one ASAP. PyTorch does very well with this kind of introductory example.

@tianleiwu
Contributor

@wschin, you are right. The stable diffusion example is complex and uses many advanced settings, so it might not be good as a tutorial.

Simple examples can be found in the documentation: https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html

@wschin
Contributor

wschin commented Oct 6, 2023

@MinGiSa, as @tianleiwu mentioned, you need to feed your existing ONNX file when launching onnxruntime with the TensorRT and CUDA execution providers. Please don't feed TensorRT engine files to onnxruntime directly. Note that the conversion pipeline ONNX file -> TensorRT engine -> ONNX file is not supported.

Contributor

github-actions bot commented Nov 6, 2023

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

github-actions bot added the stale label on Nov 6, 2023
MinGiSa closed this as completed on Nov 7, 2023