From 400df73d8fb1f1b447bad4df4a6f237c75f165e5 Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Mon, 16 Sep 2019 13:30:12 -0700 Subject: [PATCH] TensorRT 6.0 ONNX Parser Release --- CMakeLists.txt | 10 +- ModelImporter.cpp | 15 + NvOnnxParser.h | 6 +- README.md | 36 +- ShapedWeights.cpp | 11 +- builtin_op_importers.cpp | 1012 ++++++++++++++++++++++---------------- onnx2trt_utils.cpp | 51 +- onnx2trt_utils.hpp | 225 ++++----- operators.md | 43 +- trt_utils.hpp | 11 +- 10 files changed, 797 insertions(+), 623 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 33e9c346..dede7853 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,7 +44,7 @@ set(CMAKE_CXX_STANDARD 11) # Enable compiler warnings if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-deprecated-declarations -Wno-unused-function -Wno-unused-but-set-variable") endif() if ( MSVC ) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") @@ -59,9 +59,9 @@ set(RUNTIME_LINKER_SCRIPT ${ONNX2TRT_ROOT}/libnvonnxparser_runtime.version) #-------------------------------------------------- # Version information #-------------------------------------------------- -set(ONNX2TRT_MAJOR 0) -set(ONNX2TRT_MINOR 1) -set(ONNX2TRT_PATCH 0) +set(ONNX2TRT_MAJOR 6) +set(ONNX2TRT_MINOR 0) +set(ONNX2TRT_PATCH 1) #-------------------------------------------------- # Build configurations, global to all projects @@ -159,6 +159,7 @@ if(${CMAKE_VERSION} VERSION_LESS ${CMAKE_VERSION_THRESHOLD}) -lineinfo \ -g \ --expt-extended-lambda \ + -Xcompiler -Wno-deprecated-declarations \ ${GENCODES} \ ") @@ -174,6 +175,7 @@ else() -lineinfo \ -g \ --expt-extended-lambda \ + -Xcompiler -Wno-deprecated-declarations \ ${GENCODES} \ ") diff --git a/ModelImporter.cpp b/ModelImporter.cpp index cf58fbb6..887d4478 100644 --- a/ModelImporter.cpp +++ b/ModelImporter.cpp @@ -132,6 +132,21 @@ Status importInputs(ImporterContext* importer_ctx, ASSERT_INPUT(!tensors->count(input.name()), ErrorCode::kINVALID_GRAPH,input.name()); tensors->insert({input.name(), tensor}); } + + // According to the ONNX spec: initializers do not have to be specified as agraph input. + // In order for these initializers to be populated down to TRT, we need to add them to the tensors list. 
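// (For example, a convolution's weight tensor may appear only in graph.initializer and
// never in graph.input; the loop below registers such weights so that later node
// importers can still look them up by name.)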
+ for (auto initializer : initializer_map) + { + const std::string initializer_name = initializer.first; + if (!tensors->count(initializer_name)) + { + const auto& initializer_weight = *initializer.second; + ShapedWeights weights; + ASSERT(convert_onnx_weights(initializer_weight, &weights), ErrorCode::kUNSUPPORTED_NODE); + tensors->insert({initializer_name, weights}); + } + } + return Status::success(); } diff --git a/NvOnnxParser.h b/NvOnnxParser.h index 55f8cd1d..ee63aa9d 100644 --- a/NvOnnxParser.h +++ b/NvOnnxParser.h @@ -26,9 +26,9 @@ #include "NvInfer.h" #include "NvOnnxParserTypedefs.h" -#define NV_ONNX_PARSER_MAJOR 0 -#define NV_ONNX_PARSER_MINOR 1 -#define NV_ONNX_PARSER_PATCH 0 +#define NV_ONNX_PARSER_MAJOR 6 +#define NV_ONNX_PARSER_MINOR 0 +#define NV_ONNX_PARSER_PATCH 1 static const int NV_ONNX_PARSER_VERSION = ((NV_ONNX_PARSER_MAJOR * 10000) + (NV_ONNX_PARSER_MINOR * 100) + diff --git a/README.md b/README.md index 44753ea3..fda9e4ad 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,11 @@ See also the [TensorRT documentation](https://docs.nvidia.com/deeplearning/sdk/# ## Supported TensorRT Versions -Development on the Master branch is for the latest version of [TensorRT (5.1)](https://developer.nvidia.com/nvidia-tensorrt-download) +Development on the Master branch is for the latest version of [TensorRT 6.0](https://developer.nvidia.com/nvidia-tensorrt-download) -For versions < 5.1, clone and build from the [5.0 branch](https://github.com/onnx/onnx-tensorrt/tree/v5.0) +For version 5.1, clone and build from the [5.1 branch](https://github.com/onnx/onnx-tensorrt/tree/5.1) + +For version < 5.1, clone and build from the [5.0 branch](https://github.com/onnx/onnx-tensorrt/tree/v5.0) ## Supported Operators @@ -19,8 +21,8 @@ Current supported ONNX operators are found in the [operator support matrix](oper ### Dependencies - - [Protobuf](https://github.com/google/protobuf/releases) - - [TensorRT](https://developer.nvidia.com/tensorrt) + - [Protobuf >= 3.8.x](https://github.com/google/protobuf/releases) + - [TensorRT 6.0](https://developer.nvidia.com/tensorrt) ### Download the code Clone the code from GitHub. @@ -31,7 +33,7 @@ Clone the code from GitHub. The TensorRT-ONNX executables and libraries are built with CMAKE. Note by default CMAKE will tell the CUDA compiler generate code for the latest SM version. If you are using a GPU with a lower SM version you can specify which SMs to build for by using the optional `-DGPU_ARCHS` flag. For example, if you have a GTX 1080, you can specify `-DGPU_ARCHS="61"` to generate CUDA code specifically for that card. -See [here](https://developer.nvidia.com/cuda-gpus) for finding what maximum compute capability your specific GPU supports. +See [here](https://developer.nvidia.com/cuda-gpus) for the compute capability matrix for your specific GPU. mkdir build cd build @@ -56,6 +58,15 @@ See more usage information by running: onnx2trt -h +### Python modules +Python bindings for the ONNX-TensorRT parser are packaged in the shipped `.whl` files. Install them with + + pip install /python/tensorrt-6.0.1.5-cp27-none-linux_x86_64.whl + +TensorRT 6.0 supports ONNX release 1.5.0. 
Install it with: + + pip install onnx==1.5.0 + ## ONNX Python backend usage The TensorRT backend for ONNX can be used in Python as follows: @@ -84,22 +95,15 @@ libnvonnxparser_runtime.so, which has its C++ API declared in this header: NvOnnxParserRuntime.h -### Python modules -Python bindings for the ONNX-TensorRT parser in TensorRT versions >= 5.0 are packaged in the shipped `.whl` files. Install them with - - pip install /python/tensorrt-5.1.6.0-cp27-none-linux_x86_64.whl - -For earlier versions of TensorRT, the Python wrappers are built using SWIG. -Build the Python wrappers and modules by running: +Important typedefs required for parsing ONNX models are declared in this header: - python setup.py build - sudo python setup.py install + NvOnnxParserTypedefs.h ### Docker image Build the onnx_tensorrt Docker image by running: - cp /path/to/TensorRT-5.1.*.tar.gz . + cp /path/to/TensorRT-6.0.*.tar.gz . docker build -t onnx_tensorrt . ### Tests @@ -118,4 +122,4 @@ You can use `-v` flag to make output more verbose. ## Pre-trained models -Pre-trained models in ONNX format can be found at the [ONNX Model Zoo](https://github.com/onnx/models) +Pre-trained models in ONNX format can be found at the [ONNX Model Zoo](https://github.com/onnx/models) \ No newline at end of file diff --git a/ShapedWeights.cpp b/ShapedWeights.cpp index 47271458..e4c9c333 100644 --- a/ShapedWeights.cpp +++ b/ShapedWeights.cpp @@ -44,11 +44,16 @@ bool convertINT64(void* weightValues, const size_t nbWeights, std::vectorshape.nbDims == 0 ) { + if( this->values == nullptr && this->shape.nbDims == 0 ) + { return 0; - } else { + } + else + { + // TRT supports scalars, so 0D tensors should have a count of 1. size_t c = 1; - for( int i=0; ishape.nbDims; ++i ) { + for( int i=0; ishape.nbDims; ++i ) + { c *= this->shape.d[i]; } return c; diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp index 655976d4..bb441859 100644 --- a/builtin_op_importers.cpp +++ b/builtin_op_importers.cpp @@ -29,13 +29,24 @@ #include "InstanceNormalization.hpp" #include // For std::iota -#include +#include // For std::min, std::max +#include // For std::stringstream namespace onnx2trt { namespace { -enum { BATCH_DIM = 0 }; +// In ONNX, the 0th dimension is the batch dimension. +// In our TRT network, we strip this dimension where applicable. +constexpr int BATCH_DIM = 0; + +inline nvinfer1::Dims makeDims(int nbDims, int val) +{ + nvinfer1::Dims dims; + dims.nbDims = nbDims; + std::fill_n(dims.d, nbDims, val); + return dims; +} // Returns false if the transpose does not require any data movement (i.e., it's equivalent to a reshape) bool is_transpose_required(nvinfer1::Dims const& shape, @@ -46,12 +57,12 @@ bool is_transpose_required(nvinfer1::Dims const& shape, int src_i = perm.order[dst_i]; if( shape.d[src_i] != 1 ) { if( src_i < prev_significant_dim ) { - return false; + return true; } prev_significant_dim = src_i; } } - return true; + return false; } // Note: perm should not include the batch dim @@ -65,17 +76,16 @@ transpose_tensor(IImporterContext* ctx, return nullptr; } nvinfer1::Dims shape = tensor.getDimensions(); - // Check if we need to transpose the data - if( !is_transpose_required(shape, perm) ) { + // If a transpose is required, add transpose property to the shuffle layer. + if( is_transpose_required(shape, perm) ) { layer->setFirstTranspose(perm); } - // Transpose can be simplified to be a reshape if no data re-ordering is required. + // Else, the transpose can be simplified to a reshape. 
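// Worked example: shape (1, 3, 1, 5) with perm {1, 0, 3, 2} only moves size-1 axes,
// so the non-unit extents (3 and 5) keep their relative order; no data movement is
// needed and the shuffle below only has to rewrite the dimensions in permuted order.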
else { nvinfer1::Dims new_shape; new_shape.nbDims = shape.nbDims; - for (int i = 0; i < new_shape.nbDims; i++) - { + for( int i=0; isetReshapeDimensions(new_shape); @@ -123,6 +133,7 @@ NodeImportResult unaryHelper(IImporterContext* ctx, return {{layer->getOutput(0)}}; } +// float* is the poor man's std::optional NodeImportResult activationHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, std::vector& inputs, nvinfer1::ActivationType op, float* alpha = nullptr, float* beta = nullptr) { @@ -158,9 +169,10 @@ addScale(IImporterContext* ctx, nvinfer1::Weights power) { nvinfer1::ITensor* tensor_ptr = &tensor_; nvinfer1::Dims dims = tensor_ptr->getDimensions(); -#if NV_TENSORRT_MAJOR >= 4 - bool need_to_expand_dims = (dims.nbDims != 3); - nvinfer1::Dims orig_shape = dims; + + // TRT supports 2D or 3D scale, others need expand dims and add shuffle layers. + const bool need_to_expand_dims = (dims.nbDims != 3 && dims.nbDims != 4); + const nvinfer1::Dims orig_shape = dims; if( need_to_expand_dims ) { // Expand or squash dims to 3D nvinfer1::Dims new_shape = dims; @@ -174,9 +186,9 @@ addScale(IImporterContext* ctx, ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); dims = tensor_ptr->getDimensions(); } -#endif // NV_TENSORRT_MAJOR >= 4 - ASSERT(dims.nbDims == 3, ErrorCode::kUNSUPPORTED_NODE); + const int nbSpatialDims = dims.nbDims - 1; + ASSERT((nbSpatialDims == 2 || nbSpatialDims == 3), ErrorCode::kUNSUPPORTED_NODE); // Fill in dtype for any unused (dummy) weights nvinfer1::DataType* dtype_ptr = nullptr; if( shift.count ) { @@ -196,18 +208,16 @@ addScale(IImporterContext* ctx, shift.type = *dtype_ptr; scale.type = *dtype_ptr; power.type = *dtype_ptr; - auto* layer = ctx->network()->addScale( - *tensor_ptr, mode, shift, scale, power); + auto* layer = ctx->network()->addScaleNd( + *tensor_ptr, mode, shift, scale, power, tensor_ptr->getDimensions().nbDims - nbSpatialDims - 1); ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); tensor_ptr = layer->getOutput(0); -#if NV_TENSORRT_MAJOR >= 4 if( need_to_expand_dims ) { // Un-expand spatial dims back to 1D tensor_ptr = reshape_tensor(ctx, *tensor_ptr, orig_shape); ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); } -#endif // NV_TENSORRT_MAJOR >= 4 return {{tensor_ptr}}; } @@ -266,42 +276,48 @@ combineTensorsElementwise(IImporterContext* ctx, nvinfer1::ElementWiseOperation binary_op, bool legacy_binary_op_broadcasting=false) { ASSERT(!inputs.empty(), ErrorCode::kINVALID_NODE); - if (ctx->getOpsetVersion() < 7 && legacy_binary_op_broadcasting) { + if (ctx->getOpsetVersion() < 7 && legacy_binary_op_broadcasting) + { ASSERT(inputs.size() == 2, ErrorCode::kINTERNAL_ERROR); TRT_CHECK(applyLegacyBinaryOpBroadcasting(ctx, node, inputs[0], inputs[1])); } std::vector input_tensors; int ndim_max = -1; int tensors_ndim_max = -1; - for( auto input : inputs ) { + for (auto input : inputs) + { ndim_max = std::max(ndim_max, input.shape().nbDims); // Note: Tensor dims always exclude the batch dim, but weights may not - if( input.is_tensor() ) { + if(input.is_tensor()) + { tensors_ndim_max = std::max(tensors_ndim_max, input.shape().nbDims); } } - for( auto input : inputs ) { + for (auto input : inputs) + { nvinfer1::ITensor* tensor_ptr; -#if NV_TENSORRT_MAJOR < 4 - ASSERT(input.is_tensor(), ErrorCode::kUNSUPPORTED_NODE); - tensor_ptr = &input.tensor(); -#else - if( input.is_weights() ) { + if( input.is_weights() ) + { auto weights = input.weights(); // Note: TRT supports broadcasting, but ranks must match - if( input.shape().nbDims < ndim_max ) { + if( 
input.shape().nbDims < ndim_max ) + { weights.shape = expand_dims(weights.shape, ndim_max); } - if (weights.shape.nbDims == tensors_ndim_max + 1) { + if (weights.shape.nbDims == tensors_ndim_max + 1) + { // The weights contain a batch dim, which must be removed // Note: TRT Constant layer has implicit batch dim of 1 ASSERT(weights.shape.d[BATCH_DIM] == 1, ErrorCode::kUNSUPPORTED_NODE); weights.shape = remove_dim(weights.shape, BATCH_DIM); } + // Add a constant layer to convert weights to tensor. auto* layer = ctx->network()->addConstant(weights.shape, weights); ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); tensor_ptr = layer->getOutput(0); - } else { + } + else + { tensor_ptr = &input.tensor(); // Support broadcasting for tensor inputs by expanding dimensions. if (tensor_ptr->getDimensions().nbDims != tensors_ndim_max) @@ -312,7 +328,6 @@ combineTensorsElementwise(IImporterContext* ctx, ASSERT(tensor_ptr->getDimensions().nbDims == tensors_ndim_max, ErrorCode::kUNSUPPORTED_NODE); } -#endif input_tensors.push_back(tensor_ptr); } nvinfer1::ITensor* combined = input_tensors.at(0); @@ -332,16 +347,17 @@ combineTensorsElementwise(IImporterContext* ctx, return {{combined}}; } -// Note: As of TRT 4, ElementWise + Constant is preferred over Scale layer -#if NV_TENSORRT_MAJOR < 4 Status check_broadcast_attrs(IImporterContext* ctx, OnnxAttrs const& attrs, - nvinfer1::Dims const& dims) { - if (ctx->getOpsetVersion() < 7) { + nvinfer1::Dims const& dims) +{ + if (ctx->getOpsetVersion() < 7) + { ASSERT(attrs.count("broadcast"), ErrorCode::kUNSUPPORTED_NODE); bool broadcast = attrs.get("broadcast"); ASSERT(broadcast || dims.nbDims == 1, ErrorCode::kINVALID_NODE); int axis = attrs.get("axis", -1); - TRT_CHECK(convert_axis(axis, dims.nbDims)); + int nbDims = dims.nbDims; + TRT_CHECK(convert_axis(axis, nbDims)); ASSERT(axis == 0, ErrorCode::kUNSUPPORTED_NODE); } return Status::success(); @@ -355,33 +371,38 @@ enum ScaleOp { NodeImportResult importScaleOp(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, - TensorOrWeights& input0, - TensorOrWeights& input1, + std::vector& inputs, ScaleOp op) { - auto* tensor_ptr = (input0.is_tensor() ? - &input0.tensor() : - &input1.tensor()); - auto weights = (input0.is_weights() ? - input0.weights() : - input1.weights()); + nvinfer1::ITensor* tensor_ptr = (inputs.at(0).is_tensor() ? + &inputs.at(0).tensor() : + &inputs.at(1).tensor()); + ShapedWeights weights = (inputs.at(0).is_weights() ? + inputs.at(0).weights() : + inputs.at(1).weights()); nvinfer1::Dims dims = tensor_ptr->getDimensions(); // Note: ONNX opset >= 7 uses Numpy-style broadcasting, so dims are padded // at the end with ones for broadcasting. weights.shape = squeeze_trailing_dims(weights.shape); - nvinfer1::ScaleMode mode = get_scale_mode(weights.shape); - if( mode == nvinfer1::ScaleMode::kELEMENTWISE ) { - // TODO: TRT doesn't support including the batch dim in elementwise, - // but we can't do a more specific assertion here yet because - // the input tensor's shape may have been padded to WAR TRT's - // shape issues. 
- ASSERT(get_shape_size(weights.shape) == get_shape_size(dims), - ErrorCode::kUNSUPPORTED_NODE); - } else if( mode == nvinfer1::ScaleMode::kCHANNEL ) { - OnnxAttrs attrs(node); - // TRT does not currently support full broadcasting - TRT_CHECK(check_broadcast_attrs(ctx, attrs, dims)); - ASSERT(weights.shape.d[0] == dims.d[0], - ErrorCode::kUNSUPPORTED_NODE); + nvinfer1::ScaleMode mode = get_scale_mode(weights.shape, dims); + if (mode == nvinfer1::ScaleMode::kELEMENTWISE) + { + nvinfer1::ElementWiseOperation elementwise_op = {}; + switch (op) + { + case kSHIFT: elementwise_op = nvinfer1::ElementWiseOperation::kSUM; break; + case kSCALE: elementwise_op = nvinfer1::ElementWiseOperation::kPROD; break; + case kPOWER: elementwise_op = nvinfer1::ElementWiseOperation::kPOW; break; + } + // If shapes do not entirely match up, an elementwise layer is needed instead + // to support full broadcasting. + if (get_shape_size(weights.shape) != get_shape_size(dims)) + { + return combineTensorsElementwise(ctx, + node, + inputs, + elementwise_op, + true); + } } nvinfer1::Weights shift_weights = {}; nvinfer1::Weights scale_weights = {}; @@ -394,7 +415,6 @@ NodeImportResult importScaleOp(IImporterContext* ctx, return addScale( ctx, *tensor_ptr, mode, shift_weights, scale_weights, power_weights); } -#endif // NV_TENSORRT_MAJOR < 4 } // namespace @@ -436,7 +456,7 @@ bool registerBuiltinOpImporter(std::string op, #define RETURN_FIRST_OUTPUT(layer) do { \ nvinfer1::ILayer* layer_ptr = layer; \ - ASSERT(layer_ptr != nullptr, ErrorCode::kUNSUPPORTED_NODE); \ + ASSERT(layer_ptr, ErrorCode::kUNSUPPORTED_NODE); \ return {{layer_ptr->getOutput(0)}}; \ } while(0) @@ -446,7 +466,6 @@ bool registerBuiltinOpImporter(std::string op, return {{output}}; \ } while(0) -#if NV_TENSORRT_MAJOR >= 4 // Helper for ArgMax/ArgMin NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, std::vector& inputs, nvinfer1::TopKOperation op) @@ -457,6 +476,7 @@ NodeImportResult argMinMaxHelper(IImporterContext* ctx, OnnxAttrs attrs(node); int keepdims = attrs.get("keepdims", 1); int axis = attrs.get("axis", 0); + int nbDims = tensor.getDimensions().nbDims; // Adjust axis to TensorRT format TRT_CHECK(convert_axis(axis, nbDims)); @@ -490,7 +510,6 @@ NodeImportResult argMinMaxHelper(IImporterContext* ctx, return {{squeezeLayer->getOutput(0)}}; } } -#endif // #if NV_TENSORRT_MAJOR >= 4 DEFINE_BUILTIN_OP_IMPORTER(Abs) { return apply_unary_function(ctx, inputs.at(0), nvinfer1::UnaryOperation::kABS); @@ -507,12 +526,15 @@ DEFINE_BUILTIN_OP_IMPORTER(Acosh) } DEFINE_BUILTIN_OP_IMPORTER(Add) { - ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE); + if (inputs.at(0).is_tensor() != inputs.at(1).is_tensor()) + { + return importScaleOp( + ctx, node, inputs, ScaleOp::kSHIFT); + } return combineTensorsElementwise( ctx, node, inputs, nvinfer1::ElementWiseOperation::kSUM, true); } -#if NV_TENSORRT_MAJOR >= 4 DEFINE_BUILTIN_OP_IMPORTER(ArgMax) { return argMinMaxHelper(ctx, node, inputs, nvinfer1::TopKOperation::kMAX); @@ -522,8 +544,6 @@ DEFINE_BUILTIN_OP_IMPORTER(ArgMin) { return argMinMaxHelper(ctx, node, inputs, nvinfer1::TopKOperation::kMIN); } -#endif // #if NV_TENSORRT_MAJOR >= 4 - DEFINE_BUILTIN_OP_IMPORTER(Asin) { @@ -546,11 +566,8 @@ DEFINE_BUILTIN_OP_IMPORTER(Atanh) } DEFINE_BUILTIN_OP_IMPORTER(AveragePool) { - // TensorRT 5.1 only supports up to opset 9. 
- ASSERT(ctx->getOpsetVersion() < 10, ErrorCode::kUNSUPPORTED_NODE); nvinfer1::ITensor* tensor_ptr = &convertToTensor(inputs.at(0), ctx); nvinfer1::Dims dims = tensor_ptr->getDimensions(); -#if NV_TENSORRT_MAJOR >= 4 bool need_to_expand_dims = (dims.nbDims == 2); if( need_to_expand_dims ) { // Expand spatial dims from 1D to 2D @@ -559,12 +576,26 @@ DEFINE_BUILTIN_OP_IMPORTER(AveragePool) { ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); dims = tensor_ptr->getDimensions(); } -#endif // NV_TENSORRT_MAJOR >= 4 + + // Support for opset10 ceil_mode + CeilingPoolDim ceilingPool; + // Ceiling and dialations added in opset 10 + if (ctx->getOpsetVersion() >= 10) + { + OnnxAttrs attrs(node); + const auto ceil_mode = attrs.get("ceil_mode", 0); + const auto dilations = attrs.get>("dilations", std::vector (2, 1)); + for(size_t i = 0; i < dilations.size(); i++) ASSERT(dilations[i] == 1, ErrorCode::kUNSUPPORTED_NODE); // Do not support pooling dilations currently + if (ceil_mode != 0) // Need to set pooling formula to use ceiling instead of floor + { + ctx->network()->setPoolingOutputDimensionsFormula(&ceilingPool); + } + } + ASSERT(dims.nbDims == 3, ErrorCode::kUNSUPPORTED_NODE); nvinfer1::DimsHW kernel_size(1, 1), strides(1, 1), beg_padding(0, 0), end_padding(0, 0); nvinfer1::PaddingMode paddingMode; - get_kernel_params(node, get_DimsHW_from_CHW(dims), - &kernel_size, &strides, &beg_padding, &end_padding, paddingMode); + get_kernel_params(node, &kernel_size, &strides, &beg_padding, &end_padding, paddingMode); nvinfer1::IPoolingLayer* pooling_layer = ctx->network()->addPooling( *tensor_ptr, nvinfer1::PoolingType::kAVERAGE, kernel_size); nvinfer1::ILayer* layer = pooling_layer; @@ -596,14 +627,12 @@ DEFINE_BUILTIN_OP_IMPORTER(AveragePool) { } tensor_ptr = layer->getOutput(0); dims = tensor_ptr->getDimensions(); -#if NV_TENSORRT_MAJOR >= 4 if( need_to_expand_dims ) { // Un-expand spatial dims back to 1D nvinfer1::Dims new_shape{2, {dims.d[0], dims.d[1]}}; tensor_ptr = reshape_tensor(ctx, *tensor_ptr, new_shape); ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); } -#endif // NV_TENSORRT_MAJOR >= 4 return {{tensor_ptr}}; } @@ -653,13 +682,9 @@ DEFINE_BUILTIN_OP_IMPORTER(BatchNormalization) { combined_bias_weights, combined_scale_weights, {}); } -DEFINE_BUILTIN_OP_IMPORTER(Ceil) { - return unaryHelper(ctx, node, inputs, nvinfer1::UnaryOperation::kCEIL); -} - DEFINE_BUILTIN_OP_IMPORTER(Cast) { - // Get input node. OnnxAttrs attrs(node); + // Get data type to cast to. auto cast_dtype = attrs.get("to"); auto * tensor_ptr = &convertToTensor(inputs.at(0), ctx); auto trt_dtype = tensor_ptr->getType(); @@ -672,6 +697,10 @@ DEFINE_BUILTIN_OP_IMPORTER(Cast) { RETURN_FIRST_OUTPUT(layer); } +DEFINE_BUILTIN_OP_IMPORTER(Ceil) { + return unaryHelper(ctx, node, inputs, nvinfer1::UnaryOperation::kCEIL); +} + DEFINE_BUILTIN_OP_IMPORTER(Clip) { OnnxAttrs attrs(node); // beta is the upper bound. 
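For reference, the ceil_mode path added to the pooling importers above only changes how the output extent is rounded. A minimal standalone sketch of that arithmetic, assuming the standard ONNX pooling output formula (the helper name and the values are illustrative, not part of the patch):

    #include <cmath>
    #include <cstdio>

    // Pooled output extent along one spatial axis: ONNX uses floor by default,
    // and ceil_mode=1 switches the rounding to ceiling.
    int pooledExtent(int in, int kernel, int pad, int stride, bool ceilMode)
    {
        double e = static_cast<double>(in + 2 * pad - kernel) / stride + 1.0;
        return ceilMode ? static_cast<int>(std::ceil(e)) : static_cast<int>(std::floor(e));
    }

    int main()
    {
        // e.g. extent 7, kernel 2, stride 2, no padding: floor gives 3, ceil gives 4.
        std::printf("%d %d\n", pooledExtent(7, 2, 0, 2, false), pooledExtent(7, 2, 0, 2, true));
        return 0;
    }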
@@ -683,15 +712,14 @@ DEFINE_BUILTIN_OP_IMPORTER(Clip) { DEFINE_BUILTIN_OP_IMPORTER(Concat) { std::vector tensors; for( auto& input : inputs ) { -#if NV_TENSORRT_MAJOR >= 4 - ASSERT(input.is_tensor() && input.tensor().getType() != nvinfer1::DataType::kINT32, + nvinfer1::ITensor* tensor_ptr = &convertToTensor(input, ctx); + ASSERT(input.tensor().getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE); -#endif // NV_TENSORRT_MAJOR >= 4 - tensors.push_back(&convertToTensor(input, ctx)); + tensors.push_back(tensor_ptr); } OnnxAttrs attrs(node); - int nbDims = inputs.at(0).shape().nbDims; int axis = attrs.get("axis"); + int nbDims = inputs.at(0).shape().nbDims; TRT_CHECK(convert_axis(axis, nbDims)); auto* layer = ctx->network()->addConcatenation(tensors.data(), tensors.size()); ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); @@ -700,18 +728,18 @@ DEFINE_BUILTIN_OP_IMPORTER(Concat) { } DEFINE_BUILTIN_OP_IMPORTER(Constant) { - // TODO: This silently fails if the dtype is not supported - OnnxAttrs attrs(node); - return {{attrs.get("value")}}; + // TODO: This silently fails if the dtype is not supported + OnnxAttrs attrs(node); + return {{attrs.get("value")}}; } DEFINE_BUILTIN_OP_IMPORTER(Conv) { // Convolution weights must be an initializer ASSERT(inputs.at(1).is_weights(), ErrorCode::kUNSUPPORTED_NODE); + nvinfer1::ITensor* tensor_ptr = &convertToTensor(inputs.at(0), ctx); auto kernel_weights = inputs.at(1).weights(); nvinfer1::Dims dims = tensor_ptr->getDimensions(); - #if NV_TENSORRT_MAJOR >= 4 bool need_to_expand_dims = (dims.nbDims == 2); if( need_to_expand_dims ) { // Expand spatial dims from 1D to 2D @@ -724,9 +752,10 @@ DEFINE_BUILTIN_OP_IMPORTER(Conv) { kernel_weights.shape.nbDims = 4; kernel_weights.shape.d[3] = 1; } - #endif // NV_TENSORRT_MAJOR >= 4 - ASSERT(dims.nbDims == 3, ErrorCode::kUNSUPPORTED_NODE); - ASSERT(kernel_weights.shape.nbDims == 4, ErrorCode::kUNSUPPORTED_NODE); + const int nbSpatialDims = dims.nbDims - 1; + ASSERT((nbSpatialDims == 2 && kernel_weights.shape.nbDims == 4) || + (nbSpatialDims == 3 && kernel_weights.shape.nbDims == 5), ErrorCode::kUNSUPPORTED_NODE); //NOW only support 2D/3D + nvinfer1::Weights bias_weights; if( inputs.size() == 3 ) { ASSERT(inputs.at(2).is_weights(), ErrorCode::kUNSUPPORTED_NODE); @@ -737,42 +766,45 @@ DEFINE_BUILTIN_OP_IMPORTER(Conv) { } else { bias_weights = ShapedWeights::empty(kernel_weights.type); } - nvinfer1::DimsHW kernel_size; - kernel_size.h() = kernel_weights.shape.d[2]; - kernel_size.w() = kernel_weights.shape.d[3]; - nvinfer1::DimsHW strides(1, 1); - nvinfer1::DimsHW beg_padding(0, 0), end_padding(0, 0); - nvinfer1::DimsHW dilations(1, 1); + nvinfer1::Dims kernel_size; + kernel_size.nbDims = nbSpatialDims; + for(int i = 1; i <= nbSpatialDims; ++i){ + kernel_size.d[nbSpatialDims - i] = kernel_weights.shape.d[kernel_weights.shape.nbDims - i]; + } + nvinfer1::Dims strides = makeDims(nbSpatialDims, 1); + nvinfer1::Dims beg_padding = makeDims(nbSpatialDims, 0); + nvinfer1::Dims end_padding = makeDims(nbSpatialDims, 0); + nvinfer1::Dims dilations = makeDims(nbSpatialDims, 1); nvinfer1::PaddingMode paddingMode; - get_kernel_params(node, get_DimsHW_from_CHW(dims), &kernel_size, - &strides, &beg_padding, &end_padding, paddingMode, &dilations); - ASSERT(kernel_size.h() == kernel_weights.shape.d[2], ErrorCode::kINVALID_NODE); - ASSERT(kernel_size.w() == kernel_weights.shape.d[3], ErrorCode::kINVALID_NODE); + get_kernel_params(node, &kernel_size, &strides, &beg_padding, &end_padding, paddingMode, &dilations); + + for(int i = 
1; i <= nbSpatialDims; ++i){ + ASSERT(kernel_size.d[nbSpatialDims - i] == kernel_weights.shape.d[kernel_weights.shape.nbDims - i], ErrorCode::kUNSUPPORTED_NODE); + } + int nchan = dims.d[0]; int noutput = kernel_weights.shape.d[0]; // Note: Weights order is KCRS - nvinfer1::IConvolutionLayer* layer = ctx->network()->addConvolution( + nvinfer1::IConvolutionLayer* layer = ctx->network()->addConvolutionNd( *tensor_ptr, noutput, kernel_size, kernel_weights, bias_weights); ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); - layer->setStride(strides); + layer->setStrideNd(strides); layer->setPaddingMode(paddingMode); layer->setPrePadding(beg_padding); layer->setPostPadding(end_padding); - layer->setDilation(dilations); + layer->setDilationNd(dilations); OnnxAttrs attrs(node); int ngroup = attrs.get("group", 1); ASSERT(kernel_weights.shape.d[1] * ngroup == nchan, ErrorCode::kINVALID_NODE); layer->setNbGroups(ngroup); tensor_ptr = layer->getOutput(0); dims = tensor_ptr->getDimensions(); - #if NV_TENSORRT_MAJOR >= 4 if( need_to_expand_dims ) { // Un-expand spatial dims back to 1D nvinfer1::Dims new_shape{2, {dims.d[0], dims.d[1]}}; tensor_ptr = reshape_tensor(ctx, *tensor_ptr, new_shape); ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); } - #endif // NV_TENSORRT_MAJOR >= 4 return {{tensor_ptr}}; } @@ -782,7 +814,6 @@ DEFINE_BUILTIN_OP_IMPORTER(ConvTranspose) { nvinfer1::ITensor* tensor_ptr = &convertToTensor(inputs.at(0), ctx); auto kernel_weights = inputs.at(1).weights(); nvinfer1::Dims dims = tensor_ptr->getDimensions(); -#if NV_TENSORRT_MAJOR >= 4 bool need_to_expand_dims = (dims.nbDims == 2); if( need_to_expand_dims ) { // Expand spatial dims from 1D to 2D @@ -795,9 +826,10 @@ DEFINE_BUILTIN_OP_IMPORTER(ConvTranspose) { kernel_weights.shape.nbDims = 4; kernel_weights.shape.d[3] = 1; } -#endif // NV_TENSORRT_MAJOR >= 4 - ASSERT(dims.nbDims == 3, ErrorCode::kUNSUPPORTED_NODE); - ASSERT(kernel_weights.shape.nbDims == 4, ErrorCode::kUNSUPPORTED_NODE); + const int nbSpatialDims = dims.nbDims - 1; + // TensorRT supports 2D and 3D deconvolutions + ASSERT((nbSpatialDims == 2 && kernel_weights.shape.nbDims == 4) || + (nbSpatialDims == 3 && kernel_weights.shape.nbDims == 5), ErrorCode::kUNSUPPORTED_NODE); nvinfer1::Weights bias_weights; if( inputs.size() == 3 ) { ASSERT(inputs.at(2).is_weights(), ErrorCode::kUNSUPPORTED_NODE); @@ -810,58 +842,58 @@ DEFINE_BUILTIN_OP_IMPORTER(ConvTranspose) { bias_weights = ShapedWeights::empty(kernel_weights.type); } OnnxAttrs attrs(node); - nvinfer1::DimsHW input_shape = get_DimsHW_from_CHW(dims); - nvinfer1::DimsHW output_shape; + nvinfer1::Dims output_shape; if( attrs.count("output_shape") ) { - output_shape = attrs.get("output_shape"); + output_shape = attrs.get("output_shape"); } else { - ASSERT(attrs.get("auto_pad", std::string("VALID")) == "VALID", + ASSERT(attrs.get("auto_pad", std::string("VALID")) == "VALID" || attrs.get("auto_pad", std::string("NOTSET")) == "NOTSET" , ErrorCode::kINVALID_NODE); } - nvinfer1::DimsHW kernel_size; - kernel_size.h() = kernel_weights.shape.d[2]; - kernel_size.w() = kernel_weights.shape.d[3]; - nvinfer1::DimsHW strides(1, 1); - nvinfer1::DimsHW beg_padding(0, 0), end_padding(0, 0); - nvinfer1::DimsHW dilations(1, 1); + nvinfer1::Dims kernel_size; + kernel_size.nbDims = nbSpatialDims; + for(int i = 1; i <= nbSpatialDims; ++i){ + kernel_size.d[nbSpatialDims - i] = kernel_weights.shape.d[kernel_weights.shape.nbDims - i]; + } + nvinfer1::Dims strides = makeDims(nbSpatialDims, 1); + nvinfer1::Dims beg_padding = makeDims(nbSpatialDims, 0); + 
nvinfer1::Dims end_padding = makeDims(nbSpatialDims, 0); + nvinfer1::Dims dilations = makeDims(nbSpatialDims, 1); nvinfer1::PaddingMode paddingMode; // Note: output_shape/input_shape are swapped here so that the padding // calculations operate as if it were a regular forward convolution. - get_kernel_params(node, output_shape, - &kernel_size, &strides, - &beg_padding, &end_padding, paddingMode, &dilations, &input_shape); - ASSERT(kernel_size.h() == kernel_weights.shape.d[2], ErrorCode::kINVALID_NODE); - ASSERT(kernel_size.w() == kernel_weights.shape.d[3], ErrorCode::kINVALID_NODE); - ASSERT(dims.nbDims == 3, ErrorCode::kUNSUPPORTED_NODE); + get_kernel_params(node, &kernel_size, &strides, &beg_padding, &end_padding, paddingMode, &dilations); + for(int i = 1; i <= nbSpatialDims; ++i){ + ASSERT(kernel_size.d[nbSpatialDims - i] == kernel_weights.shape.d[kernel_weights.shape.nbDims - i], ErrorCode::kUNSUPPORTED_NODE); + ASSERT(dilations.d[nbSpatialDims - i] == 1, ErrorCode::kUNSUPPORTED_GRAPH); + } + int nchan = dims.d[0]; int ngroup = attrs.get("group", 1); int noutput = kernel_weights.shape.d[1] * ngroup; // Note: Weights order is CKRS - nvinfer1::IDeconvolutionLayer* deconv_layer = ctx->network()->addDeconvolution( + nvinfer1::IDeconvolutionLayer* deconv_layer = ctx->network()->addDeconvolutionNd( *tensor_ptr, noutput, kernel_size, kernel_weights, bias_weights); nvinfer1::ILayer* layer = deconv_layer; ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); - deconv_layer->setStride(strides); + deconv_layer->setStrideNd(strides); if( !attrs.count("output_shape") && attrs.count("output_padding") ) { - auto output_padding = attrs.get("output_padding"); - end_padding.h() -= output_padding.h(); - end_padding.w() -= output_padding.w(); + auto output_padding = attrs.get("output_padding"); + for (int i = 1; i <= nbSpatialDims; ++i){ + end_padding.d[nbSpatialDims - i] -= end_padding.d[nbSpatialDims - i]; + } } deconv_layer->setPaddingMode(paddingMode); deconv_layer->setPrePadding(beg_padding); deconv_layer->setPostPadding(end_padding); - ASSERT(dilations.h() == 1 && dilations.w() == 1, ErrorCode::kUNSUPPORTED_NODE); ASSERT(kernel_weights.shape.d[0] == nchan, ErrorCode::kINVALID_NODE); deconv_layer->setNbGroups(ngroup); tensor_ptr = layer->getOutput(0); dims = tensor_ptr->getDimensions(); -#if NV_TENSORRT_MAJOR >= 4 if( need_to_expand_dims ) { // Un-expand spatial dims back to 1D nvinfer1::Dims new_shape{2, {dims.d[0], dims.d[1]}}; tensor_ptr = reshape_tensor(ctx, *tensor_ptr, new_shape); ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); } - #endif // NV_TENSORRT_MAJOR >= 4 return {{tensor_ptr}}; } @@ -875,8 +907,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Cosh) return unaryHelper(ctx, node, inputs, nvinfer1::UnaryOperation::kCOSH); } - -#if NV_TENSORRT_MAJOR >= 4 DEFINE_BUILTIN_OP_IMPORTER(DepthToSpace) { nvinfer1::ITensor* tensor_ptr = &convertToTensor(inputs.at(0), ctx); nvinfer1::IShuffleLayer* layer = ctx->network()->addShuffle(*tensor_ptr); @@ -914,26 +944,25 @@ DEFINE_BUILTIN_OP_IMPORTER(DepthToSpace) { ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); return {{tensor_ptr}}; } -#endif // NV_TENSORRT_MAJOR >= 4 DECLARE_BUILTIN_OP_IMPORTER(Mul); DEFINE_BUILTIN_OP_IMPORTER(Div) { ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE); return combineTensorsElementwise( - ctx, node, inputs, nvinfer1::ElementWiseOperation::kDIV, true); + ctx, node, inputs, nvinfer1::ElementWiseOperation::kDIV, true); } DEFINE_BUILTIN_OP_IMPORTER(Dropout) { - // TensorRT 5.1 only supports up to opset 9. 
- ASSERT(ctx->getOpsetVersion() < 10, ErrorCode::kUNSUPPORTED_NODE); int noutputs = node.output().size(); if (noutputs == 1) { RETURN_IDENTITY(inputs.at(0)); } - else + else { - // Return both Dropout outputs: (output + mask) + // Error if opset version >= 10 as boolean tensors are not supported + ASSERT(ctx->getOpsetVersion() < 10, ErrorCode::kUNSUPPORTED_NODE); + // Add identity layer twice for both Dropout outputs: (output + mask) std::vector outputs; outputs.push_back(identity(ctx,inputs.at(0))); outputs.push_back(identity(ctx,inputs.at(0))); @@ -958,19 +987,17 @@ DEFINE_BUILTIN_OP_IMPORTER(Flatten) { // operation, because we can't remove or merge into the batch dim. ASSERT(axis == 1, ErrorCode::kUNSUPPORTED_NODE); nvinfer1::Dims dims = inputs.at(0).shape(); - nvinfer1::ITensor* tensor_ptr; -#if NV_TENSORRT_MAJOR < 4 - // Note: TRT3 requires that the shape remain 3D (CHW) - tensor_ptr = flatten_tensor(ctx, convertToTensor(inputs.at(0), ctx)); -#else // NV_TENSORRT_MAJOR >= 4 + nvinfer1::ITensor* tensor_ptr = &convertToTensor(inputs.at(0), ctx); nvinfer1::Dims new_shape{1, {(int)get_shape_size(dims)}}; - tensor_ptr = reshape_tensor(ctx, convertToTensor(inputs.at(0), ctx), new_shape); -#endif // NV_TENSORRT_MAJOR >= 4 + tensor_ptr = reshape_tensor(ctx, *tensor_ptr, new_shape); ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); return {{tensor_ptr}}; } -#if NV_TENSORRT_MAJOR >= 4 +DEFINE_BUILTIN_OP_IMPORTER(Floor) { + return unaryHelper(ctx, node, inputs, nvinfer1::UnaryOperation::kFLOOR); +} + DEFINE_BUILTIN_OP_IMPORTER(Gather) { nvinfer1::ITensor& data = convertToTensor(inputs.at(0), ctx); nvinfer1::ITensor& indices = convertToTensor(inputs.at(1), ctx); @@ -980,13 +1007,9 @@ DEFINE_BUILTIN_OP_IMPORTER(Gather) { TRT_CHECK(convert_axis(axis, nbDims)); RETURN_FIRST_OUTPUT(ctx->network()->addGather(data, indices, axis)); } -#endif // NV_TENSORRT_MAJOR >= 4 -DEFINE_BUILTIN_OP_IMPORTER(Floor) { - return unaryHelper(ctx, node, inputs, nvinfer1::UnaryOperation::kFLOOR); -} - -DEFINE_BUILTIN_OP_IMPORTER(Gemm) { +DEFINE_BUILTIN_OP_IMPORTER(Gemm) +{ OnnxAttrs attrs(node); float alpha = attrs.get("alpha", 1.f); float beta = attrs.get("beta", 1.f); @@ -1037,11 +1060,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm) { inputB = &inputs.at(1).tensor(); } - if (ctx->getOpsetVersion() < 7) - { - ASSERT(attrs.get("broadcast", false), ErrorCode::kUNSUPPORTED_NODE); - } - nvinfer1::ITensor* inputASqueezed = &inputA; nvinfer1::Dims newDims = squeeze_trailing_dims(inputA.getDimensions()); // When A has more than 2 dimensions, it needs to be flattened. @@ -1059,21 +1077,24 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm) { constexpr auto getMatrixOp = [] (const nvinfer1::ITensor& input, bool transpose) { - return (input.getDimensions().nbDims == 1) ? - nvinfer1::MatrixOperation::kVECTOR : - (transpose) ? 
- nvinfer1::MatrixOperation::kTRANSPOSE : - nvinfer1::MatrixOperation::kNONE; + if (input.getDimensions().nbDims == 1) + { + return nvinfer1::MatrixOperation::kVECTOR; + } + else if (transpose) + { + return nvinfer1::MatrixOperation::kTRANSPOSE; + } + return nvinfer1::MatrixOperation::kNONE; }; nvinfer1::MatrixOperation opA = getMatrixOp(*inputASqueezed, transA); nvinfer1::MatrixOperation opB = getMatrixOp(*inputB, transB); - if (opA == nvinfer1::MatrixOperation::kVECTOR && opB == nvinfer1::MatrixOperation::kVECTOR) - { - ASSERT(inputASqueezed->getDimensions() == inputB->getDimensions(), ErrorCode::kUNSUPPORTED_NODE); - } + { + ASSERT(inputASqueezed->getDimensions() == inputB->getDimensions(), ErrorCode::kUNSUPPORTED_NODE); + } nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(*inputASqueezed, opA, *inputB, opB); nvinfer1::ITensor* matmulTensor = matmul->getOutput(0); @@ -1098,8 +1119,15 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm) { biasTensor = scaledBias->getOutput(0); } + // A*B may be lower rank than C in TRT, so need to squeeze C. + if (ctx->getOpsetVersion() < 7 && !attrs.get("broadcast", false)) + { + nvinfer1::Dims squeezeDims = squeeze_leading_dims(biasTensor->getDimensions()); + biasTensor = reshape_tensor(ctx, *biasTensor, squeezeDims); + } broadcast_tensors(ctx, matmulTensor, biasTensor); - RETURN_FIRST_OUTPUT(ctx->network()->addElementWise(*matmulTensor, *biasTensor, nvinfer1::ElementWiseOperation::kSUM)); + nvinfer1::IElementWiseLayer* biasAdd = ctx->network()->addElementWise(*matmulTensor, *biasTensor, nvinfer1::ElementWiseOperation::kSUM); + return {{biasAdd->getOutput(0)}}; } DEFINE_BUILTIN_OP_IMPORTER(GlobalAveragePool) { @@ -1154,14 +1182,14 @@ DEFINE_BUILTIN_OP_IMPORTER(ImageScaler) { } DEFINE_BUILTIN_OP_IMPORTER(InstanceNormalization) { - // Scales and bias must be an initializer + // Scale and bias must initializers ASSERT(inputs.at(1).is_weights(), ErrorCode::kUNSUPPORTED_NODE); ASSERT(inputs.at(2).is_weights(), ErrorCode::kUNSUPPORTED_NODE); auto scale_weights = inputs.at(1).weights(); auto bias_weights = inputs.at(2).weights(); OnnxAttrs attrs(node); float epsilon = attrs.get("epsilon", 1e-5f); - // Lock maximum epislon value to 1e-4f. + // Lock maximum epsilon value to 1e-4f epsilon = std::max(epsilon, 1e-4f); RETURN_FIRST_OUTPUT( ctx->addPluginV2( @@ -1169,7 +1197,8 @@ DEFINE_BUILTIN_OP_IMPORTER(InstanceNormalization) { {&convertToTensor(inputs.at(0), ctx)})); } -DEFINE_BUILTIN_OP_IMPORTER(LeakyRelu) { +DEFINE_BUILTIN_OP_IMPORTER(LeakyRelu) +{ OnnxAttrs attrs(node); float alpha = attrs.get("alpha", 0.01f); return activationHelper(ctx, node, inputs, nvinfer1::ActivationType::kLEAKY_RELU, &alpha); @@ -1215,7 +1244,8 @@ DEFINE_BUILTIN_OP_IMPORTER(MatMul) { nvinfer1::MatrixOperation opA = getMatrixOp(inputA); nvinfer1::MatrixOperation opB = getMatrixOp(inputB); - RETURN_FIRST_OUTPUT(ctx->network()->addMatrixMultiply(inputA, opA, inputB, opB)); + nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(inputA, opA, inputB, opB); + return {{matmul->getOutput(0)}}; } DEFINE_BUILTIN_OP_IMPORTER(Max) { @@ -1224,12 +1254,10 @@ DEFINE_BUILTIN_OP_IMPORTER(Max) { } DEFINE_BUILTIN_OP_IMPORTER(MaxPool) { - // TensorRT 5.1 only supports up to opset 9. 
- ASSERT(ctx->getOpsetVersion() < 10, ErrorCode::kUNSUPPORTED_NODE); nvinfer1::ITensor* tensor_ptr = &convertToTensor(inputs.at(0), ctx); nvinfer1::Dims dims = tensor_ptr->getDimensions(); ASSERT(dims.nbDims >= 2, ErrorCode::kINVALID_NODE); -#if NV_TENSORRT_MAJOR >= 4 + bool need_to_expand_dims = (dims.nbDims == 2); if( need_to_expand_dims ) { // Expand spatial dims from 1D to 2D @@ -1238,33 +1266,50 @@ DEFINE_BUILTIN_OP_IMPORTER(MaxPool) { ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); dims = tensor_ptr->getDimensions(); } -#endif // NV_TENSORRT_MAJOR >= 4 - ASSERT(dims.nbDims == 3, ErrorCode::kUNSUPPORTED_NODE); - nvinfer1::DimsHW kernel_size(1, 1), strides(1, 1), beg_padding(0, 0), end_padding(0, 0); + + // Support for opset10 ceil_mode + CeilingPoolDim ceilingPool; + // Ceiling and dialations added in opset 10 + if (ctx->getOpsetVersion() >= 10) + { + OnnxAttrs attrs(node); + const auto ceil_mode = attrs.get("ceil_mode", 0); + const auto dilations = attrs.get>("dilations", std::vector (2, 1)); + for(size_t i = 0; i < dilations.size(); i++) ASSERT(dilations[i] == 1, ErrorCode::kUNSUPPORTED_NODE); // Do not support pooling dilations currently + if (ceil_mode != 0) // Need to set pooling formula to use ceiling instead of floor + { + ctx->network()->setPoolingOutputDimensionsFormula(&ceilingPool); + } + } + + int nbSpatialDims = dims.nbDims - 1; + ASSERT(nbSpatialDims == 2 || nbSpatialDims == 3, ErrorCode::kUNSUPPORTED_NODE); + nvinfer1::Dims kernel_size = makeDims(nbSpatialDims, 1); + nvinfer1::Dims strides = makeDims(nbSpatialDims, 1); + nvinfer1::Dims beg_padding = makeDims(nbSpatialDims, 0); + nvinfer1::Dims end_padding = makeDims(nbSpatialDims, 0); nvinfer1::PaddingMode paddingMode; - get_kernel_params(node, get_DimsHW_from_CHW(dims), + get_kernel_params(node, &kernel_size, &strides, &beg_padding, &end_padding, paddingMode); - nvinfer1::IPoolingLayer* layer = ctx->network()->addPooling( + nvinfer1::IPoolingLayer* layer = ctx->network()->addPoolingNd( *tensor_ptr, nvinfer1::PoolingType::kMAX, kernel_size); ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); - layer->setStride(strides); + layer->setStrideNd(strides); layer->setPaddingMode(paddingMode); layer->setPrePadding(beg_padding); layer->setPostPadding(end_padding); tensor_ptr = layer->getOutput(0); dims = tensor_ptr->getDimensions(); -#if NV_TENSORRT_MAJOR >= 4 + if( need_to_expand_dims ) { // Un-expand spatial dims back to 1D nvinfer1::Dims new_shape{2, {dims.d[0], dims.d[1]}}; tensor_ptr = reshape_tensor(ctx, *tensor_ptr, new_shape); ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE); } -#endif // NV_TENSORRT_MAJOR >= 4 return {{tensor_ptr}}; } -#if NV_TENSORRT_MAJOR >= 4 DEFINE_BUILTIN_OP_IMPORTER(Mean) { auto sum_result = combineTensorsElementwise( ctx, node, inputs, nvinfer1::ElementWiseOperation::kSUM); @@ -1289,7 +1334,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Mean) { ctx->network()->addElementWise( sum_tensor, scale_constant, nvinfer1::ElementWiseOperation::kPROD)); } -#endif // NV_TENSORRT_MAJOR >= 4 DEFINE_BUILTIN_OP_IMPORTER(Min) { return combineTensorsElementwise( @@ -1297,7 +1341,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Min) { } DEFINE_BUILTIN_OP_IMPORTER(Mul) { - ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE); + // Explicit precision networks need scale layer over elementwise + if (inputs.at(0).is_tensor() != inputs.at(1).is_tensor()) + { + return importScaleOp( + ctx, node, inputs, ScaleOp::kSCALE); + } return combineTensorsElementwise( ctx, node, inputs, nvinfer1::ElementWiseOperation::kPROD, true); } @@ -1348,58 +1397,80 @@ 
DEFINE_BUILTIN_OP_IMPORTER(ParametricSoftplus) { } DEFINE_BUILTIN_OP_IMPORTER(Pow) { + + if (inputs.at(0).is_tensor() != inputs.at(1).is_tensor()) + { + return importScaleOp( + ctx, node, inputs, ScaleOp::kPOWER); + } ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE); return combineTensorsElementwise( - ctx, node, inputs, nvinfer1::ElementWiseOperation::kPOW, true); -} - -// TODO: Prelu is currently ONLY supported with a constant scale factor, making it -// identcal with LeakyRelu. Removing the op from the registry until it is fully supported. - -// DEFINE_BUILTIN_OP_IMPORTER(PRelu) { -// ASSERT(inputs.at(0).is_tensor(), ErrorCode::kUNSUPPORTED_NODE); -// ASSERT(inputs.at(1).is_weights(), ErrorCode::kUNSUPPORTED_NODE); -// ShapedWeights weights = inputs.at(1).weights(); -// ASSERT(weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT, -// ErrorCode::kUNSUPPORTED_NODE); -// // TODO: Add support for per-channel scale factor -// nvinfer1::Dims scalar_shape{1, {1}}; -// ASSERT(weights.shape == scalar_shape, ErrorCode::kUNSUPPORTED_NODE); -// float alpha = *reinterpret_cast(weights.values); -// RETURN_FIRST_OUTPUT( -// ctx->addPluginV2( -// new FancyActivationPlugin(FancyActivationPlugin::LEAKY_RELU, alpha), -// {&inputs.at(0).tensor()})); -// } + ctx, node, inputs, nvinfer1::ElementWiseOperation::kPOW, true); +} + +DEFINE_BUILTIN_OP_IMPORTER(PRelu) { + ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE); + nvinfer1::ITensor& input = convertToTensor(inputs.at(0), ctx); + const auto& shape1 = inputs.at(0).shape(); + nvinfer1::ITensor* slopes{}; + if (inputs.at(1).is_tensor()) + { + const auto& shape2 = inputs.at(1).shape(); + ASSERT(shape1.nbDims == shape2.nbDims, ErrorCode::kUNSUPPORTED_NODE); + for (int i = 0; i < shape1.nbDims; ++i) + { + ASSERT(shape1.d[i] == shape2.d[i] || shape2.d[i] == 1, ErrorCode::kUNSUPPORTED_NODE); + } + slopes = &convertToTensor(inputs.at(1), ctx); + } + else + { + auto weights = inputs.at(1).weights(); + if (inputs.at(1).shape().nbDims < shape1.nbDims) + { + weights.shape = expand_dims(weights.shape, shape1.nbDims); + } + else if (inputs.at(1).shape().nbDims > shape1.nbDims) + { + weights.shape = remove_dim(weights.shape, BATCH_DIM); + } + auto constantLayer = ctx->network()->addConstant(weights.shape, weights); + ASSERT(constantLayer, ErrorCode::kUNSUPPORTED_NODE); + slopes = constantLayer->getOutput(0); + } + ASSERT(input.getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE); + ASSERT(slopes->getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE); + RETURN_FIRST_OUTPUT(ctx->network()->addParametricReLU(input, *slopes)); +} DEFINE_BUILTIN_OP_IMPORTER(Reciprocal) { return apply_unary_function(ctx, inputs.at(0), nvinfer1::UnaryOperation::kRECIP); } -#if NV_TENSORRT_MAJOR >= 4 NodeImportResult reduceTensor(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, TensorOrWeights input, nvinfer1::ReduceOperation operation) { - nvinfer1::ITensor& tensor = convertToTensor(input, ctx); - OnnxAttrs attrs(node); - bool keepdims = attrs.get("keepdims", 1); - int ndim = tensor.getDimensions().nbDims; - std::vector axes; - if( attrs.count("axes") ) { - axes = attrs.get>("axes"); - } else { - axes.resize(ndim); - std::iota(axes.begin(), axes.end(), 0); - } - uint32_t axis_mask = 0; - for( int axis : axes ) { - // Adjust axis to TensorRT format - TRT_CHECK(convert_axis(axis, ndim)); - axis_mask |= 1 << axis; - } - RETURN_FIRST_OUTPUT( - ctx->network()->addReduce(tensor, operation, axis_mask, keepdims)); + nvinfer1::ITensor& tensor = 
convertToTensor(input, ctx); + OnnxAttrs attrs(node); + bool keepdims = attrs.get("keepdims", 1); + int ndim = tensor.getDimensions().nbDims; + std::vector axes; + if( attrs.count("axes") ) { + axes = attrs.get>("axes"); + } else { + axes.resize(ndim); + std::iota(axes.begin(), axes.end(), 0); + } + + uint32_t axisMask = 0; + for (int axis : axes) { + TRT_CHECK(convert_axis(axis, ndim)); + axisMask |= 1 << axis; + } + + RETURN_FIRST_OUTPUT( + ctx->network()->addReduce(tensor, operation, axisMask, keepdims)); } DEFINE_BUILTIN_OP_IMPORTER(ReduceL1) { NodeImportResult abs_result = apply_unary_function( @@ -1463,12 +1534,9 @@ DEFINE_BUILTIN_OP_IMPORTER(ReduceSumSquare) { ctx, node, sqr_tensor_ptr, nvinfer1::ReduceOperation::kSUM); } -#endif // NV_TENSORRT_MAJOR >= 4 - -DEFINE_BUILTIN_OP_IMPORTER(Relu) { - RETURN_FIRST_OUTPUT( - ctx->network()->addActivation( - inputs.at(0).tensor(), nvinfer1::ActivationType::kRELU)); +DEFINE_BUILTIN_OP_IMPORTER(Relu) +{ + return activationHelper(ctx, node, inputs, nvinfer1::ActivationType::kRELU); } DEFINE_BUILTIN_OP_IMPORTER(Reshape) { @@ -1490,65 +1558,40 @@ DEFINE_BUILTIN_OP_IMPORTER(Reshape) { OnnxAttrs attrs(node); new_shape = attrs.get("shape"); } - int infer_dim = -1; - if( input.is_weights() ) { - auto weights = input.weights(); - TRT_CHECK(get_infer_dim(infer_dim,new_shape)); - if (infer_dim >= 0) - { - // Check that the -1 Dimension is correct. - ASSERT(get_shape_size(weights.shape) % (-1*get_shape_size(new_shape)) == 0, - ErrorCode::kINVALID_NODE); - - // Update the dim to the correct value - int new_dim = get_shape_size(weights.shape) / (-1*get_shape_size(new_shape)); - new_shape.d[infer_dim] = new_dim; - weights.shape = new_shape; - ASSERT(get_shape_size(new_shape) == get_shape_size(weights.shape), - ErrorCode::kUNSUPPORTED_NODE); - return {{weights}}; - } - else - { - weights.shape = new_shape; - return {{weights}}; - } + int infer_dim = -1; + if( input.is_weights() ) { + auto weights = input.weights(); + TRT_CHECK(get_infer_dim(infer_dim,new_shape)); + if (infer_dim >= 0) + { + // Update the dim to the correct value + int new_dim = get_shape_size(weights.shape) / (-1 * get_shape_size(new_shape)); + new_shape.d[infer_dim] = new_dim; } - else + ASSERT(get_shape_size(new_shape) == get_shape_size(weights.shape), + ErrorCode::kUNSUPPORTED_NODE); + weights.shape = new_shape; + return {{weights}}; + } + else { + new_shape = set_dims_CHW(remove_dim(new_shape, BATCH_DIM)); + nvinfer1::ITensor& tensor = input.tensor(); + TRT_CHECK(get_infer_dim(infer_dim,new_shape)); + if (infer_dim >= 0) { - nvinfer1::ITensor& tensor = input.tensor(); - new_shape = set_dims_CHW(remove_dim(new_shape, BATCH_DIM)); - // Check for -1 dimension in new shape - TRT_CHECK(get_infer_dim(infer_dim,new_shape)); - - if (infer_dim < 0) { - ASSERT(get_shape_size(new_shape) == - get_shape_size(tensor.getDimensions()), - ErrorCode::kUNSUPPORTED_NODE); - } -#if NV_TENSORRT_MAJOR < 4 - if( new_shape.nbDims == 1 ) { - // Note: TRT implicitly flattens the input to FC layers, and in fact - // requires that it still has 4D shape, so in this case we - // simply ignore the reshape. 
- RETURN_IDENTITY(inputs.at(0)); - } else { - ASSERT(new_shape.nbDims == 3, ErrorCode::kUNSUPPORTED_NODE); - nvinfer1::IShuffleLayer* layer = ctx->network()->addShuffle(tensor); - ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); - layer->setReshapeDimensions(new_shape); - ASSERT(get_shape_size(layer->getOutput(0)->getDimensions()) == - get_shape_size(input.shape()), ErrorCode::kUNSUPPORTED_NODE); - RETURN_FIRST_OUTPUT(layer); - } -#else // NV_TENSORRT_MAJOR >= 4 - nvinfer1::IShuffleLayer* layer = ctx->network()->addShuffle(tensor); - ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); - layer->setReshapeDimensions(new_shape); - ASSERT(get_shape_size(layer->getOutput(0)->getDimensions()) == - get_shape_size(input.shape()), ErrorCode::kUNSUPPORTED_NODE); - RETURN_FIRST_OUTPUT(layer); -#endif // NV_TENSORRT_MAJOR >= 4 + // Update the dim to the correct value + int new_dim = get_shape_size(tensor.getDimensions()) / (-1 * get_shape_size(new_shape)); + new_shape.d[infer_dim] = new_dim; + } + + ASSERT(get_shape_size(new_shape) == get_shape_size(tensor.getDimensions()), + ErrorCode::kUNSUPPORTED_NODE); + nvinfer1::IShuffleLayer* layer = ctx->network()->addShuffle(tensor); + ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); + layer->setReshapeDimensions(new_shape); + ASSERT(get_shape_size(layer->getOutput(0)->getDimensions()) == + get_shape_size(input.shape()), ErrorCode::kUNSUPPORTED_NODE); + RETURN_FIRST_OUTPUT(layer); } } @@ -1622,23 +1665,14 @@ DEFINE_BUILTIN_OP_IMPORTER(Shape) { return {{weights}}; } -DEFINE_BUILTIN_OP_IMPORTER(Sigmoid) { - RETURN_FIRST_OUTPUT( - ctx->network()->addActivation( - convertToTensor(inputs.at(0), ctx), nvinfer1::ActivationType::kSIGMOID)); -} - -DEFINE_BUILTIN_OP_IMPORTER(Sin) -{ - return unaryHelper(ctx, node, inputs, nvinfer1::UnaryOperation::kSIN); -} - -DEFINE_BUILTIN_OP_IMPORTER(Sinh) +DEFINE_BUILTIN_OP_IMPORTER(Sigmoid) { - return unaryHelper(ctx, node, inputs, nvinfer1::UnaryOperation::kSINH); + return activationHelper(ctx, node, inputs, nvinfer1::ActivationType::kSIGMOID); } DEFINE_BUILTIN_OP_IMPORTER(Size) { + // Can't support tensors because we don't know the batch dim until runtime + ASSERT(inputs.at(0).is_tensor(), ErrorCode::kUNSUPPORTED_NODE); auto shape = inputs.at(0).shape(); nvinfer1::Dims weight_dims; weight_dims.nbDims = 1; @@ -1652,22 +1686,37 @@ DEFINE_BUILTIN_OP_IMPORTER(Size) { } DEFINE_BUILTIN_OP_IMPORTER(Slice) { - // TensorRT 5.1 only supports up to opset 9. 
- ASSERT(ctx->getOpsetVersion() < 10, ErrorCode::kUNSUPPORTED_NODE); - nvinfer1::ITensor& tensor = convertToTensor(inputs.at(0), ctx); - OnnxAttrs attrs(node); - const auto starts = attrs.get>("starts"); - const auto ends = attrs.get>("ends"); - auto axes = attrs.get>("axes"); - // If axes are empty, follow the ONNX spec and populate it with [0, 1, ..., len(starts) - 1] - if (axes.size() == 0) + ASSERT(inputs.at(0).is_tensor(), ErrorCode::kUNSUPPORTED_NODE); + // If opset version >= 10 slice paramerters are weights instead of attributes + nvinfer1::ITensor& tensor = inputs.at(0).tensor(); + std::vector starts; + std::vector ends; + std::vector axes; + std::vector steps; + if(ctx->getOpsetVersion() >= 10) { - for (size_t i = 0; i < starts.size(); i++) - { - axes.push_back(i); - } + ASSERT(inputs.at(1).is_weights(), ErrorCode::kUNSUPPORTED_NODE); + int64_t * array_start = static_cast(inputs.at(1).weights().values); + starts = std::vector (array_start, array_start + inputs.at(1).weights().count()); + ASSERT(inputs.at(2).is_weights(), ErrorCode::kUNSUPPORTED_NODE); + array_start = static_cast(inputs.at(2).weights().values); + ends = std::vector (array_start, array_start + inputs.at(2).weights().count()); + ASSERT(inputs.at(3).is_weights(), ErrorCode::kUNSUPPORTED_NODE); + array_start = static_cast(inputs.at(3).weights().values); + axes = std::vector (array_start, array_start + inputs.at(3).weights().count()); + ASSERT(inputs.at(4).is_weights(), ErrorCode::kUNSUPPORTED_NODE); + array_start = static_cast(inputs.at(4).weights().values); + steps = std::vector (array_start, array_start + inputs.at(4).weights().count()); } - ASSERT(axes.size() == starts.size() && axes.size() == ends.size(), ErrorCode::kINVALID_VALUE); + else + { + OnnxAttrs attrs(node); + starts = attrs.get>("starts"); + ends = attrs.get>("ends"); + axes = attrs.get>("axes"); + steps = std::vector(starts.size(), 1); + } + ASSERT(axes.size() == starts.size() && axes.size() == ends.size() && axes.size() == steps.size(), ErrorCode::kINVALID_VALUE); const nvinfer1::Dims dims = tensor.getDimensions(); const int nbDims = dims.nbDims; @@ -1677,41 +1726,75 @@ DEFINE_BUILTIN_OP_IMPORTER(Slice) { return result; }; nvinfer1::Dims sliceStart = makeDims(0); + nvinfer1::Dims sliceEnd = dims; nvinfer1::Dims sliceSize = dims; - const nvinfer1::Dims sliceStride = makeDims(1); // ONNX has no support for strides in Slice + nvinfer1::Dims sliceStride = makeDims(1); // ONNX has support for strides before opset 10 for (size_t i = 0; i < axes.size(); i++){ + int axis = axes[i]; - if (axis == 0) { - // We can only check that starts is properly 0 - // but can't check end as we don't know batch size - ASSERT(starts[i] == 0, ErrorCode::kINVALID_VALUE); - std::cerr << "Warning: slice with starts=0 on batch axis is ignored" << std::endl; + // Special pass through for no-ops (slice across the whole dimension, [:]) + if (starts[i] == 0 && ends[i] >= dims.d[i] && steps[i] == 1) + { continue; } + + // Convert the axis if it passes the no-op check, we catch actual slices across batch dimension here TRT_CHECK(convert_axis(axis, nbDims)); - int dim = dims.d[axis]; - int start = starts[i] >= 0 ? starts[i] : dim + starts[i]; - int end = ends[i] >= 0 ? ends[i] : dim + ends[i]; - sliceStart.d[axis] = start; - sliceSize.d[axis] = end < dim ? 
end - start : dim - start; - } + // Check if slice is valid + ASSERT(steps[i] != 0, ErrorCode::kINVALID_VALUE); + sliceStride.d[axis] = steps[i]; + + // Calculate start index + // Support for negative indexing + if(starts[i] < 0) + { + sliceStart.d[axis] = std::max(dims.d[i] + static_cast(starts[i]), 0); + } + else + { + sliceStart.d[axis] = std::min(static_cast(starts[i]), dims.d[i] - 1); + } + + // Calculate end index + // Support for negative indexing + if(ends[i] < 0) + { + // Differs from start because starts is inclusive and ends is exclusive + sliceEnd.d[axis] = std::max(dims.d[i] + static_cast(ends[i]), -1); + } + else + { + sliceEnd.d[axis] = std::min(static_cast(ends[i]), dims.d[i]); + } + + sliceSize.d[axis] = std::max(static_cast(std::ceil(static_cast(sliceEnd.d[axis] - sliceStart.d[axis]) / steps[i])), 0); + } // If entire slice op was a no-op, simply return the input tensor - if (sliceStart == makeDims(0) && sliceSize == dims) + if (sliceSize == makeDims(0)) { return {{&tensor}}; } + else + { + // Slice layer can't handle size of 0 + for (size_t i = 0; i < axes.size(); i++) + { + ASSERT(sliceSize.d[i] != 0, ErrorCode::kINVALID_VALUE); + } + } RETURN_FIRST_OUTPUT(ctx->network()->addSlice(tensor, sliceStart, sliceSize, sliceStride)); } DEFINE_BUILTIN_OP_IMPORTER(Softmax) { OnnxAttrs attrs(node); int axis = attrs.get("axis", 1); - int ndim = inputs.at(0).shape().nbDims; - TRT_CHECK(convert_axis(axis, ndim)); + ASSERT(axis != BATCH_DIM, ErrorCode::kUNSUPPORTED_NODE); + int nbDims = inputs.at(0).shape().nbDims; + TRT_CHECK(convert_axis(axis, nbDims)); nvinfer1::ITensor* tensor_ptr = &convertToTensor(inputs.at(0), ctx); nvinfer1::Dims shape = tensor_ptr->getDimensions(); - // Reshape the tensor so that the softmax axis is 0 + // Reshape the tensor so that the softmax axis is 0 for TensorRT to understand if (axis > 0) { ASSERT(tensor_ptr = flatten_tensor(ctx, *tensor_ptr, axis), ErrorCode::kUNSUPPORTED_NODE); @@ -1720,7 +1803,7 @@ DEFINE_BUILTIN_OP_IMPORTER(Softmax) { auto* layer = ctx->network()->addSoftMax(*tensor_ptr); ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); tensor_ptr = layer->getOutput(0); - // Reshape the tensor back if it was reshaped above + // Reshape the tensor back to original shape if (axis > 0) { ASSERT(tensor_ptr = move_tensor_dimension(ctx, *tensor_ptr, 0, axis), ErrorCode::kUNSUPPORTED_NODE); @@ -1737,7 +1820,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Softsign) { return activationHelper(ctx, node, inputs, nvinfer1::ActivationType::kSOFTSIGN); } -#if NV_TENSORRT_MAJOR >= 4 DEFINE_BUILTIN_OP_IMPORTER(SpaceToDepth) { nvinfer1::ITensor* tensor_ptr = &convertToTensor(inputs.at(0), ctx); nvinfer1::IShuffleLayer* layer = ctx->network()->addShuffle(*tensor_ptr); @@ -1776,7 +1858,6 @@ DEFINE_BUILTIN_OP_IMPORTER(SpaceToDepth) { dims = tensor_ptr->getDimensions(); return {{tensor_ptr}}; } -#endif // NV_TENSORRT_MAJOR >= 4 // TODO: Legacy op for pre-1.0 ONNX spec; can be removed at some point DEFINE_BUILTIN_OP_IMPORTER(SpatialBN) { @@ -1815,7 +1896,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Sqrt) { return apply_unary_function(ctx, inputs.at(0), nvinfer1::UnaryOperation::kSQRT); } -#if NV_TENSORRT_MAJOR >= 4 DEFINE_BUILTIN_OP_IMPORTER(Squeeze) { nvinfer1::ITensor& tensor = convertToTensor(inputs.at(0), ctx); nvinfer1::Dims old_shape = tensor.getDimensions(); @@ -1846,13 +1926,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Squeeze) { get_shape_size(old_shape), ErrorCode::kUNSUPPORTED_NODE); RETURN_FIRST_OUTPUT(layer); } -#endif // NV_TENSORRT_MAJOR >= 4 DECLARE_BUILTIN_OP_IMPORTER(Add); 
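As a side note on the Slice importer above: the start/end clamping and the slice-size computation reduce to the arithmetic sketched below (standalone sketch with made-up values, not part of the patch):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // One axis of an ONNX Slice: clamp start/end with negative-index support and
    // derive the element count the same way the importer above does.
    int sliceCount(int dim, int start, int end, int step)
    {
        int s = start < 0 ? std::max(dim + start, 0) : std::min(start, dim - 1);
        int e = end < 0 ? std::max(dim + end, -1) : std::min(end, dim);
        return std::max(static_cast<int>(std::ceil(static_cast<double>(e - s) / step)), 0);
    }

    int main()
    {
        // dim=10, start=-4, end=9, step=2 selects indices 6 and 8, so the count is 2.
        std::printf("%d\n", sliceCount(10, -4, 9, 2));
        return 0;
    }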
DEFINE_BUILTIN_OP_IMPORTER(Sub) { ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE); return combineTensorsElementwise( - ctx, node, inputs, nvinfer1::ElementWiseOperation::kSUB, true); + ctx, node, inputs, nvinfer1::ElementWiseOperation::kSUB, true); } DEFINE_BUILTIN_OP_IMPORTER(Sum) { @@ -1866,9 +1945,7 @@ DEFINE_BUILTIN_OP_IMPORTER(Tan) } DEFINE_BUILTIN_OP_IMPORTER(Tanh) { - RETURN_FIRST_OUTPUT( - ctx->network()->addActivation( - inputs.at(0).tensor(), nvinfer1::ActivationType::kTANH)); + return activationHelper(ctx, node, inputs, nvinfer1::ActivationType::kTANH); } DEFINE_BUILTIN_OP_IMPORTER(ThresholdedRelu) { @@ -1877,28 +1954,35 @@ DEFINE_BUILTIN_OP_IMPORTER(ThresholdedRelu) { return activationHelper(ctx, node, inputs, nvinfer1::ActivationType::kTHRESHOLDED_RELU, &alpha); } -#if NV_TENSORRT_MAJOR >= 4 -DEFINE_BUILTIN_OP_IMPORTER(TopK) { - // TensorRT 5.1 only supports up to opset 9. - ASSERT(ctx->getOpsetVersion() < 10, ErrorCode::kUNSUPPORTED_NODE); - nvinfer1::ITensor& tensor = convertToTensor(inputs.at(0), ctx); - ASSERT(tensor.getType() != nvinfer1::DataType::kINT32, - ErrorCode::kUNSUPPORTED_NODE); - OnnxAttrs attrs(node); - ASSERT(attrs.count("k"), ErrorCode::kINVALID_NODE); - int k = attrs.get("k"); - int axis = attrs.get("axis", -1); - int nbDims = tensor.getDimensions().nbDims; - // Adjust axis to TensorRT format - TRT_CHECK(convert_axis(axis, nbDims)); +DEFINE_BUILTIN_OP_IMPORTER(TopK) +{ + ASSERT(inputs.at(0).is_tensor(), ErrorCode::kUNSUPPORTED_NODE); + nvinfer1::ITensor& tensor = inputs.at(0).tensor(); + ASSERT(tensor.getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE); + OnnxAttrs attrs(node); + int axis = attrs.get("axis", -1); + int k; + // Don't support TopK with k as a tensor + if(ctx->getOpsetVersion() >= 10) + { + ASSERT(inputs.at(1).is_weights(), ErrorCode::kUNSUPPORTED_NODE); + ASSERT(inputs.at(1).weights().count() == 1, ErrorCode::kUNSUPPORTED_NODE); + k = *static_cast(inputs.at(1).weights().values); + } + else + { + ASSERT(attrs.count("k"), ErrorCode::kINVALID_NODE); + k = attrs.get("k"); + } - uint32_t axis_mask = 1 << axis; - auto* layer = ctx->network()->addTopK( - tensor, nvinfer1::TopKOperation::kMAX, k, axis_mask); - ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); - return {{layer->getOutput(0), layer->getOutput(1)}}; + int nbDims = tensor.getDimensions().nbDims; + TRT_CHECK(convert_axis(axis, nbDims)); + + uint32_t axisMask = 1 << axis; + nvinfer1::ITopKLayer* layer = ctx->network()->addTopK(tensor, nvinfer1::TopKOperation::kMAX, k, axisMask); + ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); + return {{layer->getOutput(0), layer->getOutput(1)}}; } -#endif // NV_TENSORRT_MAJOR >= 4 DEFINE_BUILTIN_OP_IMPORTER(Transpose) { TensorOrWeights input = inputs.at(0); @@ -1929,86 +2013,164 @@ DEFINE_BUILTIN_OP_IMPORTER(Transpose) { } } -#if NV_TENSORRT_MAJOR >= 4 DEFINE_BUILTIN_OP_IMPORTER(Unsqueeze) { - nvinfer1::ITensor& tensor = convertToTensor(inputs.at(0), ctx); - nvinfer1::Dims old_shape = tensor.getDimensions(); - int ndim_in = old_shape.nbDims; - OnnxAttrs attrs(node); - auto axes = attrs.get>("axes"); - // If the input was already a tensor, then we're dealing with a TRT shape, - // so subtract 1 from the axes. Otherwise, this is an ONNX shape. 
- if (inputs.at(0).is_tensor()) - { - for (auto& axis : axes) - { - ASSERT(axis != BATCH_DIM, ErrorCode::kUNSUPPORTED_NODE); - --axis; - } - } + nvinfer1::ITensor& tensor = convertToTensor(inputs.at(0), ctx); + nvinfer1::Dims old_shape = tensor.getDimensions(); + int ndim_in = old_shape.nbDims; + OnnxAttrs attrs(node); + auto axes = attrs.get>("axes"); + std::set axes_set(axes.begin(), axes.end()); + int ndim_out = ndim_in + axes_set.size(); + ASSERT(ndim_out <= nvinfer1::Dims::MAX_DIMS, ErrorCode::kUNSUPPORTED_NODE); + nvinfer1::Dims new_shape; + new_shape.nbDims = ndim_out; - std::set axes_set(axes.begin(), axes.end()); - int ndim_out = ndim_in + axes_set.size(); - ASSERT(ndim_out <= nvinfer1::Dims::MAX_DIMS, ErrorCode::kUNSUPPORTED_NODE); - nvinfer1::Dims new_shape; - new_shape.nbDims = ndim_out; + // If the input was already a tensor, then we're dealing with a TRT shape, + // so subtract 1 from the axes. Otherwise, this is an ONNX shape. + if (inputs.at(0).is_tensor()) + { + for (auto& axis : axes) + { + ASSERT(axis != BATCH_DIM, ErrorCode::kUNSUPPORTED_NODE); + --axis; + } + } - for (int i = 0, j = 0; j < new_shape.nbDims; ++j ) - { - if( !axes_set.count(j) ) - { - new_shape.d[j] = old_shape.d[i++]; - } - else - { - new_shape.d[j] = 1; - } - } - nvinfer1::IShuffleLayer* layer = ctx->network()->addShuffle(tensor); - ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); - layer->setReshapeDimensions(new_shape); - ASSERT(get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape), - ErrorCode::kUNSUPPORTED_NODE); - RETURN_FIRST_OUTPUT(layer); + for (int i = 0, j = 0; j < new_shape.nbDims; ++j ) + { + if( !axes_set.count(j) ) + { + new_shape.d[j] = old_shape.d[i++]; + } + else + { + new_shape.d[j] = 1; + } + } + + nvinfer1::IShuffleLayer* layer = ctx->network()->addShuffle(tensor); + ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); + layer->setReshapeDimensions(new_shape); + ASSERT(get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape), + ErrorCode::kUNSUPPORTED_NODE); + RETURN_FIRST_OUTPUT(layer); +} + +#if NV_TENSORRT_MAJOR >= 6 +DEFINE_BUILTIN_OP_IMPORTER(Resize) { + // Retrieve and validate input tensor + nvinfer1::ITensor& input = convertToTensor(inputs.at(0), ctx); + int input_dims = input.getDimensions().nbDims; + ASSERT(input_dims > 0, ErrorCode::kUNSUPPORTED_NODE); + + OnnxAttrs attrs(node); + + // Retrive and validate scale factors. + // Scale factors include batch dimensions as well. + ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE); + auto scales = inputs.at(1); + ASSERT(scales.is_weights(), ErrorCode::kUNSUPPORTED_NODE); + ShapedWeights scales_weights = scales.weights(); + ASSERT(scales_weights.shape.nbDims == 1, ErrorCode::kUNSUPPORTED_NODE); + // Effective scales_count i.e. drop count by 1 to match input dims. + int scales_count = scales_weights.count() - 1; + ASSERT(scales_count == input_dims, ErrorCode::kUNSUPPORTED_NODE); + ASSERT(scales_weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT, ErrorCode::kINVALID_NODE); + + // Get floating point scale factors. + float const *scales_ptr = static_cast(scales_weights.values); + + // Add resize layer + nvinfer1::IResizeLayer* layer = ctx->network()->addResize(input); + + // Set resize mode + auto mode = attrs.get("mode", "nearest"); + ASSERT(mode == "nearest" || mode == "linear", ErrorCode::kUNSUPPORTED_NODE); + // Set default resize mode. Nearest resize support N-D (where 0 < N <= 8) resize. 
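    // (Descriptive note, not part of the patch: kNEAREST is only the default here and is
    //  overridden to kLINEAR below when the ONNX "mode" attribute is "linear". In TensorRT 6.0
    //  linear resize is restricted to the innermost 1, 2 or 3 dimensions, which is why the
    //  importer documents linear support as 1-D/2-D/3-D only.)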
+ nvinfer1::ResizeMode resizeMode = nvinfer1::ResizeMode::kNEAREST; + if (mode == "linear") + { + // Linear resize support 1-D, 2-D and 3D resize. + resizeMode = nvinfer1::ResizeMode::kLINEAR; + } + layer->setResizeMode(resizeMode); + + // Set resize scales + int num_scales = input_dims; + std::vector scale_factors(num_scales, 1.0f); + // Number of scale factors equals input dimensions + // Exclude batch dimension scales here. + // TODO: Update how we validate and set scales when + // TensorRT drop supports full dims. + for (int i = 0; i < num_scales; ++i) + { + scale_factors[i] = scales_ptr[i+1]; + } + layer->setScales(scale_factors.data(), num_scales); + + // Set other attributes. ONNX spec does not have this attribute yet. + // Default: False. Set it any way. + layer->setAlignCorners(false); + + // Return layer output + RETURN_FIRST_OUTPUT(layer); } -#endif // NV_TENSORRT_MAJOR >= 4 +#endif DEFINE_BUILTIN_OP_IMPORTER(Upsample) { + // Retrieve and validate input tensor nvinfer1::ITensor &tensor = convertToTensor(inputs.at(0), ctx); - ASSERT(tensor.getDimensions().nbDims == 3, ErrorCode::kUNSUPPORTED_NODE); + const int nbDims = tensor.getDimensions().nbDims; + // Input tensor has no batch dimension. + const int nbSpatialDims = nbDims - 1; + ASSERT(nbSpatialDims == 2 || nbSpatialDims == 3, ErrorCode::kUNSUPPORTED_NODE); OnnxAttrs attrs(node); - float height_scale, width_scale; + // Resize layer needs scale factors size equals input tensor dims. + const int num_scales = nbDims; + std::vector scale_factors(num_scales, 1.0f); if (ctx->getOpsetVersion() >= 9) { + // Get scale factors from inputs[1] ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE); auto scales_input = inputs.at(1); + // Retrieve and validate scale factors. ASSERT(scales_input.is_weights(), ErrorCode::kUNSUPPORTED_NODE); ShapedWeights scales_weights = scales_input.weights(); ASSERT(scales_weights.shape.nbDims == 1, ErrorCode::kUNSUPPORTED_NODE); - ASSERT(scales_weights.count() == 4, ErrorCode::kUNSUPPORTED_NODE); + // Scale factors has batch dimension. + ASSERT(static_cast(scales_weights.count()) == num_scales + 1, ErrorCode::kUNSUPPORTED_NODE); ASSERT(scales_weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT, ErrorCode::kINVALID_NODE); float const *scales_ptr = static_cast(scales_weights.values); - ASSERT(scales_ptr[0] == 1 && scales_ptr[1] == 1, - ErrorCode::kUNSUPPORTED_NODE); - height_scale = scales_ptr[2]; - width_scale = scales_ptr[3]; - } else { - if (!attrs.count("scales")) { - height_scale = attrs.get("height_scale"); - width_scale = attrs.get("width_scale"); - } else { - auto scales = attrs.get>("scales"); - ASSERT(scales.size() == 4, ErrorCode::kUNSUPPORTED_NODE); - ASSERT(scales[0] == 1 && scales[1] == 1, ErrorCode::kUNSUPPORTED_NODE); - height_scale = scales[2]; - width_scale = scales[3]; + for(int i = 0; i < num_scales; i++){ + scale_factors[i] = scales_ptr[i + 1]; + } + } + else + { + ASSERT(attrs.count("scales"), ErrorCode::kUNSUPPORTED_NODE); + // Get scale factors from OnnxAttrs. + auto scales = attrs.get>("scales"); + // Scale factors has batch dimension. 
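    // (Descriptive note, not part of the patch: unlike the opset-9 branch above, which drops
    //  the leading batch-dimension scale and copies scales_ptr[i + 1], this pre-opset-9 branch
    //  expects the "scales" attribute to provide exactly one factor per TensorRT dimension:
    //  the ASSERT below checks scales.size() == nbDims and the loop copies the values as-is.)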
+ ASSERT(static_cast(scales.size()) == nbDims, ErrorCode::kUNSUPPORTED_NODE); + for (int i = 0; i < nbDims; i++) + { + scale_factors[i] = scales[i]; } } - auto scale = {height_scale, width_scale}; auto mode = attrs.get("mode", "nearest"); - ASSERT(mode == "nearest", ErrorCode::kUNSUPPORTED_NODE); - RETURN_FIRST_OUTPUT( - ctx->addPluginV2(new ResizeNearestPlugin(scale), {&inputs.at(0).tensor()})); + ASSERT(mode == "nearest" || mode == "linear", ErrorCode::kUNSUPPORTED_NODE); + // Set default resize mode. Nearest resize support N-D (where 0 < N <= 8) resize. + nvinfer1::ResizeMode resizeMode = nvinfer1::ResizeMode::kNEAREST; + if (mode == "linear") + { + // Linear resize support 1-D, 2-D and 3D resize. + resizeMode = nvinfer1::ResizeMode::kLINEAR; + } + // Add resize layer + nvinfer1::IResizeLayer* const layer = ctx->network()->addResize(tensor); + layer->setScales(scale_factors.data(), num_scales); + layer->setResizeMode(resizeMode); + RETURN_FIRST_OUTPUT(layer); } } // namespace diff --git a/onnx2trt_utils.cpp b/onnx2trt_utils.cpp index d537e1d5..f93a0bec 100644 --- a/onnx2trt_utils.cpp +++ b/onnx2trt_utils.cpp @@ -24,33 +24,38 @@ namespace onnx2trt { +void setAttr(nvinfer1::Dims * trtAttr, ::ONNX_NAMESPACE::AttributeProto const* onnxAttr, int nbSpatialDims, int defaultVal){ + assert(trtAttr->nbDims == nbSpatialDims); + int ndim = onnxAttr->ints().size(); + for(int i = 0; i < nbSpatialDims; ++i){ + if(i < ndim){ + trtAttr->d[i] = onnxAttr->ints(i); + } else { + trtAttr->d[i] = defaultVal; + } + } +} + void get_kernel_params(::ONNX_NAMESPACE::NodeProto const& onnx_node, - nvinfer1::DimsHW const& input_shape, - nvinfer1::DimsHW* kernel_size, - nvinfer1::DimsHW* strides, - nvinfer1::DimsHW* beg_padding, - nvinfer1::DimsHW* end_padding, + nvinfer1::Dims* kernel_size, + nvinfer1::Dims* strides, + nvinfer1::Dims* beg_padding, + nvinfer1::Dims* end_padding, nvinfer1::PaddingMode& paddingMode, - nvinfer1::DimsHW* dilations, - nvinfer1::DimsHW const* output_shape) { + nvinfer1::Dims* dilations) { + const int nbSpatialDims = kernel_size->nbDims; OnnxAttrs attrs(onnx_node); if( attrs.count("kernel_shape") ) { auto const* onnx_kernel_size = attrs.at("kernel_shape"); - int ndim = onnx_kernel_size->ints().size(); - kernel_size->h() = onnx_kernel_size->ints(0); - kernel_size->w() = ndim > 1 ? onnx_kernel_size->ints(1) : 1; + setAttr(kernel_size, onnx_kernel_size, nbSpatialDims, 1); } if( attrs.count("strides") ) { auto const* onnx_strides = attrs.at("strides"); - int ndim = onnx_strides->ints().size(); - strides->h() = onnx_strides->ints(0); - strides->w() = ndim > 1 ? onnx_strides->ints(1) : 1; + setAttr(strides, onnx_strides, nbSpatialDims, 1); } if( dilations && attrs.count("dilations") ) { auto const* onnx_dilations = attrs.at("dilations"); - int ndim = onnx_dilations->ints().size(); - dilations->h() = onnx_dilations->ints(0); - dilations->w() = ndim > 1 ? onnx_dilations->ints(1) : 1; + setAttr(dilations, onnx_dilations, nbSpatialDims, 1); } paddingMode = nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN; auto onnx_auto_pad = attrs.get("auto_pad", std::string("NOTSET")); @@ -58,11 +63,15 @@ void get_kernel_params(::ONNX_NAMESPACE::NodeProto const& onnx_node, if( attrs.count("pads") ) { auto onnx_padding = attrs.get>("pads"); int ndim = onnx_padding.size() / 2; - int i = 0; - beg_padding->h() = onnx_padding.at(i++); - beg_padding->w() = ndim > 1 ? onnx_padding.at(i++) : 0; - end_padding->h() = onnx_padding.at(i++); - end_padding->w() = ndim > 1 ? 
onnx_padding.at(i++) : 0; + for(int i = 0; i < nbSpatialDims; ++i){ + if(i < ndim){ + beg_padding->d[i] = onnx_padding.at(i); + end_padding->d[i] = onnx_padding.at(i + ndim); + } else { + beg_padding->d[i] = 0; + end_padding->d[i] = 0; + } + } } } else { // SAME_* padding assert(!attrs.count("pads")); diff --git a/onnx2trt_utils.hpp b/onnx2trt_utils.hpp index 4e1b8d50..25d28b43 100644 --- a/onnx2trt_utils.hpp +++ b/onnx2trt_utils.hpp @@ -34,6 +34,20 @@ using std::cerr; using std::endl; +class CeilingPoolDim:public nvinfer1::IOutputDimensionsFormula{ +public: + nvinfer1::DimsHW compute(nvinfer1::DimsHW inputDims, nvinfer1::DimsHW kernelSize, + nvinfer1::DimsHW stride, nvinfer1::DimsHW padding, nvinfer1::DimsHW dilation, const char* layerName) const + { + nvinfer1::DimsHW outputDims; + for (int dimension = 0; dimension < inputDims.nbDims; dimension++) + { + outputDims.d[dimension] = static_cast(ceil((inputDims.d[dimension] + padding.d[dimension] * 2.0 - kernelSize.d[dimension]) / stride.d[dimension] + 1.0)); + } + return outputDims; + } +}; + inline std::ostream& operator<<(std::ostream& stream, nvinfer1::Dims const& shape) { if( shape.nbDims == 0 ) { return stream; @@ -77,19 +91,6 @@ inline std::ostream& operator<<(std::ostream& stream, google::protobuf::Message */ namespace onnx2trt { -inline nvinfer1::ITensor* reshape_tensor(IImporterContext* ctx, nvinfer1::ITensor& tensor, nvinfer1::Dims shape) -{ - if( shape == tensor.getDimensions() ) { - return &tensor; - } - nvinfer1::IShuffleLayer* layer = ctx->network()->addShuffle(tensor); - if( !layer ) { - return nullptr; - } - layer->setReshapeDimensions(shape); - return layer->getOutput(0); -} - inline int get_dtype_size(int32_t onnx_dtype) { switch( onnx_dtype ) { case ::ONNX_NAMESPACE::TensorProto::FLOAT16: return 2; @@ -133,6 +134,20 @@ inline const char* get_dtype_name(int32_t onnx_dtype) { } } +inline nvinfer1::ITensor* reshape_tensor(IImporterContext* ctx, + nvinfer1::ITensor& tensor, + nvinfer1::Dims shape) { + if( shape == tensor.getDimensions() ) { + return &tensor; + } + nvinfer1::IShuffleLayer* layer = ctx->network()->addShuffle(tensor); + if( !layer ) { + return nullptr; + } + layer->setReshapeDimensions(shape); + return layer->getOutput(0); +} + inline void broadcast_tensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1::ITensor*& t2) { if (t1->getDimensions().nbDims == t2->getDimensions().nbDims) @@ -154,7 +169,8 @@ inline void broadcast_tensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvi nvinfer1::Dims largeDims = largeTensor->getDimensions(); nvinfer1::Dims smallDims = smallTensor->getDimensions(); - nvinfer1::Dims newDims({largeDims.nbDims, {1, 1, 1, 1, 1, 1, 1, 1}}); + // Create placeholder dimensions to check broadcasting + nvinfer1::Dims newDims{largeDims.nbDims, {1, 1, 1, 1, 1, 1, 1, 1}}; int i(0), j(0); while (i < smallDims.nbDims && j < largeDims.nbDims) @@ -193,9 +209,9 @@ inline bool convert_dtype(int32_t onnx_dtype, case ::ONNX_NAMESPACE::TensorProto::INT8: *trt_dtype = nvinfer1::DataType::kINT8; break; case ::ONNX_NAMESPACE::TensorProto::FLOAT16: *trt_dtype = nvinfer1::DataType::kHALF; break; #if NV_TENSORRT_MAJOR >= 4 + case ::ONNX_NAMESPACE::TensorProto::INT32: *trt_dtype = nvinfer1::DataType::kINT32; break; // See ShapedWeights.cpp for sanity check if all values can be safetly downcasted to INT32 case ::ONNX_NAMESPACE::TensorProto::INT64: *trt_dtype = nvinfer1::DataType::kINT32; break; - case ::ONNX_NAMESPACE::TensorProto::INT32: *trt_dtype = nvinfer1::DataType::kINT32; break; #endif default: cerr 
<< "Unsupported ONNX data type: " << get_dtype_name(onnx_dtype) @@ -230,7 +246,6 @@ inline bool convert_dims(OnnxDims const& onnx_dims, nvinfer1::Dims& trt_dims) { // TODO: Unknown dimensions are represented using onnx_dim.dim_param // Dynamically sized inputs are currently not supported. Catch these cases // as onnx_dim.dim_value() == 0 on non-batch dimensions and throw an error. - //ASSERT(onnx_dims_vector.empty() || onnx_dim.dim_value() != 0, ErrorCode::kUNSUPPORTED_GRAPH); if (onnx_dims_vector.empty() || onnx_dim.dim_value() != 0) { onnx_dims_vector.push_back(onnx_dim.dim_value()); @@ -298,52 +313,50 @@ inline bool convert_weight_descriptor(onnxTensorDescriptorV1 const &desc, } inline bool convert_onnx_weights(::ONNX_NAMESPACE::TensorProto const& onnx_tensor, - onnx2trt::ShapedWeights* weights) { - nvinfer1::Dims shape; - shape.nbDims = onnx_tensor.dims().size(); - std::copy(onnx_tensor.dims().begin(), onnx_tensor.dims().end(), - shape.d); - // Special case for scalars - if( shape.nbDims == 0 ) { - shape.nbDims = 1; - shape.d[0] = 1; - shape.type[0] = nvinfer1::DimensionType::kCHANNEL; - } - auto dtype = onnx_tensor.data_type(); - void* data_ptr; // TODO: See if can make const* - size_t nbytes; - if( onnx_tensor.raw_data().size() > 0 ) { - data_ptr = (void*)onnx_tensor.raw_data().data(); - nbytes = onnx_tensor.raw_data().size(); - } else if( onnx_tensor.float_data().size() > 0 ) { - assert(dtype == ::ONNX_NAMESPACE::TensorProto::FLOAT); - data_ptr = (void*)onnx_tensor.float_data().data(); - nbytes = onnx_tensor.float_data().size() * sizeof(float); - } else if( onnx_tensor.int32_data().size() > 0 ) { - assert(dtype == ::ONNX_NAMESPACE::TensorProto::INT32 || - dtype == ::ONNX_NAMESPACE::TensorProto::INT16 || - dtype == ::ONNX_NAMESPACE::TensorProto::INT8 || - dtype == ::ONNX_NAMESPACE::TensorProto::UINT16 || - dtype == ::ONNX_NAMESPACE::TensorProto::UINT8 || - dtype == ::ONNX_NAMESPACE::TensorProto::BOOL || - dtype == ::ONNX_NAMESPACE::TensorProto::FLOAT16); - data_ptr = (void*)onnx_tensor.int32_data().data(); - nbytes = onnx_tensor.int32_data().size() * sizeof(int32_t); - } else if( onnx_tensor.int64_data().size() > 0 ) { - assert(dtype == ::ONNX_NAMESPACE::TensorProto::INT64); - data_ptr = (void*)onnx_tensor.int64_data().data(); - nbytes = onnx_tensor.int64_data().size() * sizeof(int64_t); - } else { - // Unsupported ONNX tensor format! 
- return false; - } + onnx2trt::ShapedWeights* weights) { + nvinfer1::Dims shape; + shape.nbDims = onnx_tensor.dims().size(); + std::copy(onnx_tensor.dims().begin(), onnx_tensor.dims().end(), + shape.d); + auto dtype = onnx_tensor.data_type(); + void* data_ptr; // TODO: See if can make const* + size_t nbytes; + if( onnx_tensor.raw_data().size() > 0 ) + { + data_ptr = (void*)onnx_tensor.raw_data().data(); + nbytes = onnx_tensor.raw_data().size(); + } + else if( onnx_tensor.float_data().size() > 0 ) + { + assert(onnx_tensor.data_type() == ::ONNX_NAMESPACE::TensorProto::FLOAT); + data_ptr = (void*)onnx_tensor.float_data().data(); + nbytes = onnx_tensor.float_data().size() * sizeof(float); + } + else if( onnx_tensor.int32_data().size() > 0 ) + { + // TODO: Need special handling for int8 or float16 stored as int32_data + assert(get_dtype_size(dtype) == 4); + data_ptr = (void*)onnx_tensor.int32_data().data(); + nbytes = onnx_tensor.int32_data().size() * sizeof(int32_t); + } + else if( onnx_tensor.int64_data().size() > 0 ) + { + assert(onnx_tensor.data_type() == ::ONNX_NAMESPACE::TensorProto::INT64); + data_ptr = (void*)onnx_tensor.int64_data().data(); + nbytes = onnx_tensor.int64_data().size() * sizeof(int64_t); + } + else + { + // Unsupported ONNX tensor format! + return false; + } - onnx2trt::ShapedWeights trt_weights(dtype, data_ptr, shape); - (void)nbytes; - assert(trt_weights.size_bytes() == nbytes); - *weights = trt_weights; - return true; -} + onnx2trt::ShapedWeights trt_weights(dtype, data_ptr, shape); + (void)nbytes; + assert(trt_weights.size_bytes() == nbytes); + *weights = trt_weights; + return true; + } // Returns the input if it is already a tensor. If it is of type ShapedWeights, adds a new // constant layer to the TRT network and returns its output. @@ -391,7 +404,7 @@ inline Status convert_axis(int& axis, int nbDims) { axis += nbDims; } - // If axis was positive, subtract 1 to strip batch dimension + // Subtract 1 from the axis given a postive ONNX axis to strip out batch dimension else { axis = axis - 1; @@ -411,8 +424,10 @@ inline int get_conv_output_size(int input_size, int filter_size, // Helper function to help extract the index of a potential -1 dimension in the reshape node inline Status get_infer_dim(int& infer_dim, nvinfer1::Dims const& new_shape) { - for (int i = 0; i < new_shape.nbDims; ++i) { - if (new_shape.d[i] < 0) { + for (int i = 0; i < new_shape.nbDims; ++i) + { + if (new_shape.d[i] < 0) + { // -1 bears special meaning, which means the current dimension can // be inferred while keepin the total number of elements the same. 
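      // (Worked example, not part of the patch: reshaping a 3x4x5 tensor with new_shape = {-1, 5}
      //  resolves the -1 to 12, because the element count 3*4*5 = 60 must stay constant and 60/5 = 12.)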
// https://github.com/onnx/onnx/blob/9b9f595107e3fc0295d50f6294d43879df17552f/onnx/defs/tensor/defs.cc#L73-L75 @@ -426,79 +441,29 @@ inline Status get_infer_dim(int& infer_dim, nvinfer1::Dims const& new_shape) { } void get_kernel_params(::ONNX_NAMESPACE::NodeProto const& onnx_node, - nvinfer1::DimsHW const& input_shape, - nvinfer1::DimsHW* kernel_size, - nvinfer1::DimsHW* strides, - nvinfer1::DimsHW* beg_padding, - nvinfer1::DimsHW* end_padding, + nvinfer1::Dims* kernel_size, + nvinfer1::Dims* strides, + nvinfer1::Dims* beg_padding, + nvinfer1::Dims* end_padding, nvinfer1::PaddingMode& paddingMode, - nvinfer1::DimsHW* dilations=nullptr, - nvinfer1::DimsHW const* output_shape=nullptr); - -inline nvinfer1::ScaleMode get_scale_mode(nvinfer1::Dims const& weights_shape) { - if( weights_shape.nbDims == 1 ) { - if( weights_shape.d[0] == 1 ) { - return nvinfer1::ScaleMode::kUNIFORM; - } else { - return nvinfer1::ScaleMode::kCHANNEL; - } - } else { - return nvinfer1::ScaleMode::kELEMENTWISE; - } -} + nvinfer1::Dims* dilations=nullptr); -inline void update_padded_values(std::vector&pad_values, const nvinfer1::DimsHW beg_padding, - const nvinfer1::DimsHW end_padding, const nvinfer1::Dims padded_shape, const float pad_value) +inline nvinfer1::ScaleMode get_scale_mode(nvinfer1::Dims const& weights_shape, + nvinfer1::Dims const& tensor_shape) { - int pad_h = padded_shape.d[1]; - int pad_w = padded_shape.d[2]; - int num_elements = pad_values.size(); - - // Handle H padding. First beg_padding.h * pad_w and last end_padding.h * pad_w - // elements need to be updated to pad_value - if (beg_padding.h() != 0) + if (weights_shape.nbDims == 1) { - int end = beg_padding.h() * pad_w; - for (int i = 0; i < end; i++) + if (weights_shape.d[0] == 1) { - pad_values[i] = pad_value; - } - } - if (end_padding.h() != 0) - { - for (int start = (pad_h - end_padding.h()) * pad_w; - start < num_elements; start++) - { - pad_values[start] = pad_value; - } - - } - // Handle W padding. First beg_padding.w() and last end_padding.w() - // elements of each row needs to be updated to pad_value - if (beg_padding.w() != 0) - { - for (int h_dim = 0; h_dim < pad_h; h_dim++) - { - for (int w_dim = 0; w_dim < beg_padding.w(); w_dim++) - { - int row_base_index = h_dim*pad_h; - pad_values[row_base_index + w_dim] = pad_value; - } - } - } - if (end_padding.w() != 0) - { - for (int h_dim = 0; h_dim < pad_h; h_dim++) + return nvinfer1::ScaleMode::kUNIFORM; + } + // Check for channel wide scale - assume tensor shape is CHW. + else if (weights_shape.d[0] == tensor_shape.d[0]) { - for (int w_dim = pad_w - end_padding.w(); - w_dim < pad_w; w_dim++) - { - int row_base_index = h_dim*pad_h; - pad_values[row_base_index + w_dim] = pad_value; - } + return nvinfer1::ScaleMode::kCHANNEL; } - } + } + return nvinfer1::ScaleMode::kELEMENTWISE; } - } // namespace onnx2trt diff --git a/operators.md b/operators.md index 248c12b8..5a41634f 100644 --- a/operators.md +++ b/operators.md @@ -1,6 +1,6 @@ # Supported ONNX Operators -In general, TensorRT does not support operations across the batch dimension (dimension/axis 0). TensorRT 5.1 supports operators up to Opset 9. Latest information of ONNX operators can be found [here](https://github.com/onnx/onnx/blob/master/docs/Operators.md) +In general, TensorRT does not support operations across the batch dimension (dimension/axis 0). TensorRT 6.0 supports operators up to Opset 10. 
Latest information of ONNX operators can be found [here](https://github.com/onnx/onnx/blob/master/docs/Operators.md) TensorRT supports the following ONNX data types: FLOAT32, FLOAT16, INT32, INT8, INT64* @@ -21,20 +21,22 @@ TensorRT supports the following ONNX data types: FLOAT32, FLOAT16, INT32, INT8, | Asinh | Y | 5.1 | | | Atan | Y | 5.1 | | | Atanh | Y | 5.1 | | -| AveragePool | Y | | 2D pooling only. | +| AveragePool | Y | | 2D/3D pooling only. | | BatchNormalization | Y | | | -| Cast | Y | 5.1 | | +| Cast | Y | 5.1 | Only FP16->FP32 casts are supported | | Ceil | Y | | | | Clip | Y | | | | Compress | N | N/A | | | Concat | Y | | | | Constant | Y | | | | ConstantOfShape | N | N/A | | -| Conv | Y | | 2D convolution only. Convolution weights must be baked into the graph. | -| ConvTranspose | Y | | 2D deconvolution only. Deconvolution weights must be baked into the graph. | +| Conv | Y | | 2D/3D convolution only. Convolution weights must be an initializer. | +| ConvInteger | N | N/A | | +| ConvTranspose | Y | | 2D/3D deconvolution only. Deconvolution weights must be an initializer. | | Cos | Y | 5.1 | | | Cosh | Y | | | | DepthToSpace | Y | 4.0 | | +| DequantizeLinear | N | N/A | | | Div | Y | | | | Dropout | Y | | | | Elu | Y | | | @@ -48,15 +50,16 @@ TensorRT supports the following ONNX data types: FLOAT32, FLOAT16, INT32, INT8, | GRU | N | N/A | | | Gather | Y | 4.0 | | | Gemm | Y | | | -| GlobalAveragePool | Y | | 2D pooling only. | +| GlobalAveragePool | Y | | 2D/3D pooling only. | | GlobalLpPool | N | N/A | | -| GlobalMaxPool | Y | | 2D pooling only. | +| GlobalMaxPool | Y | | 2D/3D pooling only. | | Greater | N | N/A | | | HardSigmoid | Y | | | | Hardmax | N | N/A | | | Identity | Y | | | | If | N | N/A | | | InstanceNormalization | Y | | | +| IsInf | N | N/A | | | IsNaN | N | N/A | | | LRN | Y | | | | LSTM | N | N/A | | @@ -68,22 +71,28 @@ TensorRT supports the following ONNX data types: FLOAT32, FLOAT16, INT32, INT8, | LpNormalization | N | N/A | | | LpPool | N | N/A | | | MatMul | Y | | | +| MatMulInteger | N | N/A | | | Max | Y | | | -| MaxPool | Y | | 2D pooling only. | +| MaxPool | Y | | 2D/3D pooling only. | | MaxRoiPool | N | N/A | | | MaxUnpool | N | N/A | | | Mean | Y | 4.0 | | | Min | Y | | | +| Mod | N | N/A | | | Mul | Y | | | | Multinomial | N | N/A | | | Neg | Y | | | +| NonMaxSuppression | N | N/A | | | NonZero | N | N/A | | | Not | N | N/A | | | OneHot | N | N/A | | | Or | N | N/A | | -| PRelu | Y | | | +| PRelu | Y | 6.0 | | | Pad | Y | | Zero-constant padding only. 
| | Pow | Y | | | +| QLinearConv | N | N/A | | +| QLinearMatMul | N | N/A | | +| QuantizeLinear | N | N/A | | | RNN | N | N/A | | | RandomNormal | N | N/A | | | RandomNormalLike | N | N/A | | @@ -101,6 +110,9 @@ TensorRT supports the following ONNX data types: FLOAT32, FLOAT16, INT32, INT8, | ReduceSum | Y | 4.0 | | | ReduceSumSquare | Y | 4.0 | | | Relu | Y | | | +| Resize | Y | 6.0 | | +| ReverseSequence | N | N/A | | +| RoiAlign | N | N/A | | | Reshape | Y | | | | Scan | N | N/A | | | Scatter | N | N/A | | @@ -120,11 +132,13 @@ TensorRT supports the following ONNX data types: FLOAT32, FLOAT16, INT32, INT8, | Split | Y | | | | Sqrt | Y | | | | Squeeze | Y | 4.0 | | +| StringNormalizer | N | N/A | | | Sub | Y | | | | Sum | Y | | | | Tan | Y | | | | Tanh | Y | 5.1 | | | TfIdfVectorizer | N | N/A | | +| ThresholdedRelu | Y | | | | Tile | N | N/A | | | TopK | Y | 4.0 | | | Transpose | Y | | | @@ -132,14 +146,3 @@ TensorRT supports the following ONNX data types: FLOAT32, FLOAT16, INT32, INT8, | Upsample | Y | 4.0 | | | Where | N | N/A | | | Xor | N | N/A | | -| experimental ATen | N | N/A | | -| experimental Affine | N | N/A | | -| experimental Crop | N | N/A | | -| experimental DynamicSlice | N | N/A | | -| experimental GRUUnit | N | N/A | | -| experimental GivenTensorFill | N | N/A | | -| experimental ImageScaler | Y | | | -| experimental ParametricSoftplus | Y | 5.1 | | -| experimental Scale | N | N/A | | -| experimental ScaledTanh | Y | 5.1 | | -| experimental ThresholdedRelu | Y | | | diff --git a/trt_utils.hpp b/trt_utils.hpp index 093cad33..89d73063 100644 --- a/trt_utils.hpp +++ b/trt_utils.hpp @@ -29,6 +29,7 @@ #include #include #include +#include namespace onnx2trt { @@ -131,6 +132,14 @@ inline nvinfer1::Dims squeeze_trailing_dims(nvinfer1::Dims const& dims) { return new_dims; } +inline nvinfer1::Dims squeeze_leading_dims(const nvinfer1::Dims& dims) { + nvinfer1::Dims newDims; + // Copy dims only if a non-1 has been seen already. + bool non1Seen{false}; + newDims.nbDims = std::copy_if(dims.d, dims.d + dims.nbDims, newDims.d, [&non1Seen](int x) { non1Seen = (x != 1) ? true : non1Seen; return non1Seen; }) - newDims.d; + return newDims; +} + inline nvinfer1::Dims set_dims_CHW(nvinfer1::Dims const& dims) { nvinfer1::Dims new_dims = dims; assert(new_dims.nbDims > 0); @@ -171,7 +180,7 @@ inline TensorOrWeights identity(IImporterContext* ctx, if( input.is_weights() ) { return input; } else { - auto* layer = ctx->network()->addShuffle(input.tensor()); + auto* layer = ctx->network()->addIdentity(input.tensor()); if( !layer ) { return nullptr; }
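    // (Descriptive note, not part of the patch: IIdentityLayer forwards its input tensor
    //  unchanged; it replaces the IShuffleLayer that was previously used here purely as a
    //  pass-through.)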