Tensorrt-CV

This repository is entirely about deployment. In most cases we have a DNN model trained in Python (e.g. with PyTorch), and the goal in production is low latency (measured in FPS) and high accuracy (measured by MSE). TensorRT is used here to achieve that.

In my experience, the most efficient approach is to convert the PyTorch model to ONNX and use TensorRT's ONNX parser, rather than rebuilding the whole network and loading the weights in TensorRT by hand, particularly when preprocessing and postprocessing are taken into account. The preprocessing and postprocessing computations are then treated as part of the network, saved into the TensorRT engine, and run on the GPU, as sketched below.
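Here is a minimal sketch of that path (the tiny model, input shape, and file names are placeholders, not part of this repo; it assumes the TensorRT 7.x Python API):

import torch
import tensorrt as trt

# Export a (placeholder) trained PyTorch model to ONNX, with any
# preprocess/postprocess layers already folded into the module.
net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(net, dummy, "model.onnx", opset_version=11)

# Parse the ONNX file with TensorRT's ONNX parser and build an engine.
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:
    parser.parse(f.read())
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30  # 1 GB of builder workspace
engine = builder.build_engine(network, config)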

But in some cases there are unsupported-operator issues (like grid_sampler), and not every operation is a good fit for the GPU (like matrix inversion).

Articles

PyTorch to TensorRT Model Deployment - Dynamic Shape (Batch Size)

TensorRT Pitfall Log (Bullshit Diary)

2020-11-08:
Description:
Parsing the ONNX file of a GAN model.

ERROR: Failed to parse the ONNX file.
In node 60 (importInstanceNormalization): UNSUPPORTED_NODE: Assertion failed: !isDynamic(tensor_ptr->getDimensions()) && "InstanceNormalization does not support dynamic inputs!"

Solution:
Update TensorRT to v7.2!
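To check which version the Python bindings actually see after upgrading (nothing repo-specific here):

import tensorrt as trt
print(trt.__version__)  # should report 7.2.x or newer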

2020-11-11:
Description:
Got an error when allocating memory.

...
    self.stream = cuda.Stream()
pycuda._driver.LogicError: explicit_context_dependent failed: invalid device context - no currently active context?

Solution:

import pycuda.driver as cuda
# importing pycuda.autoinit initializes the driver and creates an active context
import pycuda.autoinit

(Note that pycuda.autoinit itself takes some GPU memory.)
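Putting it together, a minimal sketch of the fixed initialization; only the import of pycuda.autoinit matters:

import pycuda.driver as cuda
import pycuda.autoinit  # creates and activates a CUDA context on import

# With an active context, creating a stream no longer raises LogicError.
stream = cuda.Stream()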

2020-12-03:
Description:
Exporting a PyTorch model to an ONNX file.

RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible

Solution:
Find where the error is raised in .../envs/pytorch1.6/lib/python3.7/site-packages/torch/onnx/ and print the offending node:

print(v.node())
# prints the failing node; in my case it pointed at
#   avg = F.avg_pool2d(feat32, feat32.size()[2:])
# print(feat32.size()[2:]) to get the actual value,
# then hard-code it as a constant kernel size

2020-12-07:
Description:
Parsing an ONNX file in TensorRT.

[TensorRT] INTERNAL ERROR: Assertion failed: cublasStatus == CUBLAS_STATUS_SUCCESS
../rtSafe/cublas/cublasLtWrapper.cpp:279
Aborting...
[TensorRT] ERROR: ../rtSafe/cublas/cublasLtWrapper.cpp (279) - Assertion Error in getCublasLtHeuristic: 0 (cublasStatus == CUBLAS_STATUS_SUCCESS)

Solution:
This is caused by a cuBLAS LT bug in CUDA 10.2. It is solved by disabling cublasLt as a tactic source:

trtexec --onnx=xxx.onnx --tacticSources=-cublasLt,+cublas --workspace=2048 --fp16 --saveEngine=xxx.engine
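If you build engines through the Python API instead of trtexec, the same workaround looks like this (assumes TensorRT >= 7.2, where set_tactic_sources is available):

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
config = builder.create_builder_config()
# Keep only cuBLAS as a tactic source, dropping cuBLAS LT.
config.set_tactic_sources(1 << int(trt.TacticSource.CUBLAS))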

2020-12-08:
Description:
Allocating buffers: device memory bindings must be ordered by binding index from the engine, which is not always the same as the input/output order; see the sketch below.
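A minimal allocation sketch that follows binding-index order rather than assuming inputs come before outputs (assumes a deserialized engine with static shapes; with dynamic shapes the sizes would come from the execution context instead):

import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt

def allocate_buffers(engine):
    # Allocate host/device buffers strictly in binding-index order.
    host_mem, device_mem, bindings = [], [], []
    for idx in range(engine.num_bindings):
        dtype = trt.nptype(engine.get_binding_dtype(idx))
        size = trt.volume(engine.get_binding_shape(idx))
        h = cuda.pagelocked_empty(size, dtype)
        d = cuda.mem_alloc(h.nbytes)
        bindings.append(int(d))  # position idx in bindings == binding index
        host_mem.append(h)
        device_mem.append(d)
    return host_mem, device_mem, bindings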

2020-12-11:
Description:
When running TensorRT inference with a saved engine:

pycuda._driver.LogicError: cuMemcpyHtoDAsync failed: invalid argument

Solution: this may be caused by a bad input buffer. Check whether the input dtype is float64; the engine expects float32 (see the sketch below).
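A quick guard before copying the input to the device (the array here is just a stand-in for the real preprocessed input):

import numpy as np

img = np.random.rand(1, 3, 224, 224)  # np.random.rand returns float64
# An FP32 engine input expects a float32 host buffer; a float64 array here
# makes cuMemcpyHtoDAsync fail with "invalid argument".
inp = np.ascontiguousarray(img, dtype=np.float32)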

2021-01-06:
Description:
I got this error when using TensorRT and PyTorch together, with the preprocessing done as a PyTorch GPU computation.

[TensorRT] ERROR: safeContext.cpp (184) - Cudnn Error in configure: 7 (CUDNN_STATUS_MAPPING_ERROR)
[TensorRT] ERROR: FAILED_EXECUTION: std::exception

I also found that the error disappears when the two parts are split into separate processes, but in my case a very large image has to be passed to the post-processing step, which adds extra latency when transferring it between processes.
I also tried CuPy instead of PyTorch and got the same error ;(

Solution: push and pop the CUDA context around the inference call:

import pycuda.driver as cuda

gpu_id = 0  # index of the GPU running the engine
cuda.init()
cuda_ctx = cuda.Device(gpu_id).make_context()

cuda_ctx.push()
# ... do inference here ...
cuda_ctx.pop()

2021-05-29:
Description:
Deserializing a saved engine that was built in a different environment failed.

Solution: we first need to manually initialize the plugins shipped with TensorRT before deserializing.

import tensorrt as trt

def load_engine(engine_file_path, trt_logger):
    # Force init TensorRT plugins before deserializing
    trt.init_libnvinfer_plugins(None, "")
    with open(engine_file_path, "rb") as f, \
            trt.Runtime(trt_logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    return engine

2021-05-31:
Description: PyTorch to ONNX succeeded, but building the TensorRT engine from the ONNX file failed with:

[05/28/2021-20:30:26] [I] [TRT] /training/colin/Github/TensorRT/parsers/onnx/ModelImporter.cpp:139: No importer registered for op: ScatterND. Attempting to import as plugin.
[05/28/2021-20:30:26] [I] [TRT] /training/colin/Github/TensorRT/parsers/onnx/builtin_op_importers.cpp:3775: Searching for plugin: ScatterND, plugin_version: 1, plugin_namespace: 
[05/28/2021-20:30:26] [E] [TRT] INVALID_ARGUMENT: getPluginCreator could not find plugin ScatterND version 1

Solution: ScatterND comes from in-place index assignment such as A[:, 0:2] = B. The fix is to replace these with splits and concatenations, as in the sketch below.
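A sketch of that rewrite for the A[:, 0:2] = B case (plain PyTorch, nothing repo-specific):

import torch

A = torch.randn(4, 8)
B = torch.randn(4, 2)

# A[:, 0:2] = B                       # in-place slice assignment -> exported as ScatterND
A = torch.cat([B, A[:, 2:]], dim=1)   # same result, exports as Slice + Concat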

2021-07-05:
Description: This is caused by nn.BatchNorm1d fed directly by a fully connected layer:

[TensorRT] ERROR: (Unnamed Layer* 11) [Shuffle]: at most one dimension may be inferred
ERROR: Failed to parse the ONNX file.
In node 1 (scaleHelper): UNSUPPORTED_NODE: Assertion failed: dims.nbDims == 4 || dims.nbDims == 5

Solution: according to the PyTorch documentation, the input to nn.BatchNorm1d can be (N, C, L) or (N, L). Unsqueezing the output of the fc layer to (N, C, L) makes it work:

...
fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
features = nn.BatchNorm1d(num_features)
...
x = fc(x)
x = x.unsqueeze(-1)  # (N, C) -> (N, C, 1) so BatchNorm1d gets a 3D input
x = features(x)

2022-03-09:
Description: Failed to parse onnx file.

UNSUPPORTED_NODE: Assertion failed: (transformationMode == "asymmetric" || transformationMode == "pytorch_half_pixel" || transformationMode == "half_pixel") && "TensorRT only supports half pixel, pytorch half_pixel, and asymmetric tranformation mode for linear resizes when scales are provided!"

Solution: I got this error from F.interpolate, because scale_factor is exported as a double, which is not allowed in the TensorRT conversion. One solution is to compute the final size instead of passing scale_factor:

# For my case
# F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
# can be rewritten as
F.interpolate(x, size=[int(2 * x.shape[2]), int(2 * x.shape[3])], mode='bilinear', align_corners=True)

2022-03-21:
Description: Failed to generate onnx file.

RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible

Solution: this error is usually caused by a non-static input size to functions like F.interpolate or F.avg_pool2d.

# For my case
# atten = F.avg_pool2d(feat, feat.size()[2:]) # non-static size
atten = F.avg_pool2d(feat, (16, 16)) # static size
