From 315293fe41b4714bacfbe9300f74f389d3040cd3 Mon Sep 17 00:00:00 2001
From: Sai Kiran Polisetty
Date: Thu, 20 Jun 2024 19:29:23 +0530
Subject: [PATCH] Add INT64 Datatype Support for Shape Tensors in TensorRT
 Backend (#91)

* Add INT64 datatype support for shape tensors
---
 CMakeLists.txt        |   2 +
 src/instance_state.cc | 109 +++++++++++---------------
 src/instance_state.h  |  13 ++--
 src/shape_tensor.cc   | 172 ++++++++++++++++++++++++++++++++++++++++++
 src/shape_tensor.h    |  77 +++++++++++++++++++
 src/tensorrt_utils.cc |  69 ++++++++++++-----
 src/tensorrt_utils.h  |  42 ++++++++++-
 7 files changed, 392 insertions(+), 92 deletions(-)
 create mode 100644 src/shape_tensor.cc
 create mode 100644 src/shape_tensor.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 62fa538..c54e07a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -141,6 +141,8 @@ add_library(
   ${SOURCE_DIR}/instance_state.cc
   ${SOURCE_DIR}/tensorrt_model_instance.cc
   ${SOURCE_DIR}/tensorrt_model_instance.h
+  ${SOURCE_DIR}/shape_tensor.cc
+  ${SOURCE_DIR}/shape_tensor.h
   ${SOURCE_DIR}/tensorrt_utils.cc
   ${SOURCE_DIR}/tensorrt_utils.h
   ${SOURCE_DIR}/filesystem.h
diff --git a/src/instance_state.cc b/src/instance_state.cc
index 518dd26..653bd4f 100644
--- a/src/instance_state.cc
+++ b/src/instance_state.cc
@@ -500,7 +500,7 @@ ModelInstanceState::Run(
     return;
   }
 
-  std::map<int, std::vector<int32_t>> request_shape_values;
+  std::map<int, ShapeTensor> request_shape_values;
   // Scheduler ensures all the requests have identical shape values so
   // use values from first shape tensor
   TRITONSERVER_Error* err = GetRequestShapeValues(
@@ -587,8 +587,7 @@ ModelInstanceState::Run(
       if (it != request_shape_values.end()) {
         err = ValidateShapeValues(
             it->second, citr->second.min_shapes_[io_index],
-            citr->second.max_shapes_[io_index], citr->second.nb_shape_values_,
-            support_batching_);
+            citr->second.max_shapes_[io_index], citr->second.nb_shape_values_);
       } else {
         err = TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_INTERNAL,
@@ -607,8 +606,8 @@ ModelInstanceState::Run(
         // [FIXME] formalize it, the 'buffer_' may be set directly while forming
         // the shape value
         memcpy(
-            io_binding_info.GetBuffer(), &(it->second[0]),
-            sizeof(int32_t) * it->second.size());
+            io_binding_info.GetBuffer(), it->second.GetData(),
+            it->second.GetSize());
         citr->second.context_->setInputTensorAddress(
             name.c_str(), io_binding_info.GetBuffer());
       }
@@ -1304,7 +1303,7 @@ ModelInstanceState::ProcessResponse()
 TRITONSERVER_Error*
 ModelInstanceState::GetRequestShapeValues(
     size_t total_batch_size, TRITONBACKEND_Request* request,
-    std::map<int, std::vector<int32_t>>* request_shape_values)
+    std::map<int, ShapeTensor>* request_shape_values)
 {
   // Visit all the inputs and extract the shape values present in the
   // request
@@ -1325,12 +1324,6 @@ ModelInstanceState::GetRequestShapeValues(
     int io_index = io_index_map_[input_name];
     if (engine_->isShapeInferenceIO(input_name)) {
-      auto it =
-          request_shape_values->emplace(io_index, std::vector<int32_t>()).first;
-      if (support_batching_) {
-        it->second.push_back((int32_t)total_batch_size);
-      }
-
       // For now being conservative and requiring that shape tensors
       // be in a single buffer on the CPU. We can handle more cases in
       // future if necessary.
@@ -1359,38 +1352,15 @@ ModelInstanceState::GetRequestShapeValues(
                 .c_str());
       }
 
-      // FIXME DLIS-6653: With the support of INT64, shape tensors
-      // can also be of type INT64 and the assumptions that shape
-      // tensors might not always hold true.
-      // Assuming input shape tensors datatype as INT32.
       int64_t element_cnt = backend::GetElementCount(shape, dims_count);
       if (support_batching_) {
         element_cnt /= shape[0];
       }
 
-      const size_t expected_byte_size =
-          element_cnt * GetByteSize(TRITONSERVER_TYPE_INT32, {1});
-
-      bool includes_batch_shape_value = false;
-      if (expected_byte_size != data_byte_size) {
-        if (expected_byte_size == (data_byte_size - sizeof(int32_t))) {
-          includes_batch_shape_value = true;
-        } else {
-          return TRITONSERVER_ErrorNew(
-              TRITONSERVER_ERROR_INVALID_ARG,
-              (std::string("shape tensor for input '") + input_name +
-               "' expected byte size is " + std::to_string(expected_byte_size) +
-               " [ or " + std::to_string(expected_byte_size + sizeof(int32_t)) +
-               " if input includes batch shape value] " + ", got " +
-               std::to_string(data_byte_size))
-                  .c_str());
-        }
-      }
-
-      const int32_t* dims = reinterpret_cast<const int32_t*>(data_buffer);
-      int64_t offset = includes_batch_shape_value ? 1 : 0;
-      for (int64_t i = offset; i < element_cnt; ++i) {
-        it->second.push_back(dims[i]);
-      }
+      auto it = request_shape_values->emplace(io_index, ShapeTensor()).first;
+      RETURN_IF_ERROR(it->second.SetDataFromBuffer(
+          data_buffer, data_byte_size, datatype, element_cnt, input_name,
+          support_batching_, total_batch_size));
     }
   }
 
@@ -1401,7 +1371,7 @@ TRITONSERVER_Error*
 ModelInstanceState::GetMostOptimizedProfile(
     size_t total_batch_size, TRITONBACKEND_Request** requests,
     uint32_t request_count,
-    const std::map<int, std::vector<int32_t>>& request_shape_values,
+    const std::map<int, ShapeTensor>& request_shape_values,
     std::map<int, TensorRTContext>::iterator* citr)
 {
   // Returns the TensorRT context that uses profile with shortest
@@ -1452,7 +1422,7 @@ TRITONSERVER_Error*
 ModelInstanceState::EvaluateTensorRTContext(
     std::map<int, TensorRTContext>::iterator& citr, size_t total_batch_size,
     TRITONBACKEND_Request** requests, uint32_t request_count,
-    const std::map<int, std::vector<int32_t>>& request_shape_values,
+    const std::map<int, ShapeTensor>& request_shape_values,
     int64_t* error_distance)
 {
   *error_distance = 0;
@@ -1519,13 +1489,12 @@ ModelInstanceState::EvaluateTensorRTContext(
       if (it != request_shape_values.end()) {
         shape_err = ValidateShapeValues(
             it->second, citr->second.min_shapes_[io_index],
-            citr->second.max_shapes_[io_index], citr->second.nb_shape_values_,
-            support_batching_);
-        valid_bs =
-            (!support_batching_) || (((int32_t)total_batch_size >=
-                                      *citr->second.min_shapes_[io_index]) &&
-                                     ((int64_t)total_batch_size <=
-                                      *citr->second.max_shapes_[io_index]));
+            citr->second.max_shapes_[io_index],
+            citr->second.nb_shape_values_);
+        valid_bs = (!support_batching_) ||
+                   ValidateBatchSize(
+                       total_batch_size, citr->second.min_shapes_[io_index],
+                       citr->second.max_shapes_[io_index]);
       } else {
         missing_shape_values = true;
       }
@@ -1549,14 +1518,9 @@ ModelInstanceState::EvaluateTensorRTContext(
               std::abs(opt_dims.d[idx] - input_shape_vec[idx - 1]);
         }
         if (engine_->isShapeInferenceIO(input_name)) {
-          const auto* opt_shape_values = citr->second.opt_shapes_[io_index];
-          *error_distance +=
-              std::abs(*opt_shape_values - (int64_t)total_batch_size);
           auto it = request_shape_values.find(io_index);
-          for (size_t idx = 1; idx < citr->second.nb_shape_values_; idx++) {
-            *error_distance +=
-                std::abs(*(opt_shape_values + idx) - it->second[idx - 1]);
-          }
+          *error_distance += it->second.GetDistance(
+              citr->second.opt_shapes_[io_index], total_batch_size);
         }
       }
     }
@@ -2996,13 +2960,14 @@ ModelInstanceState::InitializeShapeInputBinding(
     return nullptr;
   }
 
-  if (input_datatype != TRITONSERVER_TYPE_INT32) {
+  if ((input_datatype != TRITONSERVER_TYPE_INT32) &&
+      (input_datatype != TRITONSERVER_TYPE_INT64)) {
     return TRITONSERVER_ErrorNew(
         TRITONSERVER_ERROR_INVALID_ARG,
         (std::string("unexpected datatype TYPE_") +
          TRITONSERVER_DataTypeString(input_datatype) +
-         " in model configuration for shape input '" + input_name +
-         "', expecting TYPE_INT32 for " + Name())
+         " in model configuration for shape input '" + input_name +
+         "', expecting TYPE_INT32 or TYPE_INT64 for " + Name())
             .c_str());
   }
 
@@ -3042,18 +3007,32 @@ ModelInstanceState::InitializeShapeInputBinding(
   context.nb_shape_values_ = (context.max_dims_[io_index].nbDims == 0)
                                  ? 1
                                  : context.max_dims_[io_index].d[0];
-  context.max_shapes_[io_index] = engine_->getProfileTensorValues(
-      input_name.c_str(), profile_index, nvinfer1::OptProfileSelector::kMAX);
-  context.min_shapes_[io_index] = engine_->getProfileTensorValues(
-      input_name.c_str(), profile_index, nvinfer1::OptProfileSelector::kMIN);
-  context.opt_shapes_[io_index] = engine_->getProfileTensorValues(
-      input_name.c_str(), profile_index, nvinfer1::OptProfileSelector::kOPT);
+  context.max_shapes_[io_index] = ShapeTensor();
+  context.max_shapes_[io_index].SetDataFromShapeValues(
+      engine_->getProfileTensorValues(
+          input_name.c_str(), profile_index,
+          nvinfer1::OptProfileSelector::kMAX),
+      input_datatype, context.nb_shape_values_);
+
+  context.min_shapes_[io_index] = ShapeTensor();
+  context.min_shapes_[io_index].SetDataFromShapeValues(
+      engine_->getProfileTensorValues(
+          input_name.c_str(), profile_index,
+          nvinfer1::OptProfileSelector::kMIN),
+      input_datatype, context.nb_shape_values_);
+
+  context.opt_shapes_[io_index] = ShapeTensor();
+  context.opt_shapes_[io_index].SetDataFromShapeValues(
+      engine_->getProfileTensorValues(
+          input_name.c_str(), profile_index,
+          nvinfer1::OptProfileSelector::kOPT),
+      input_datatype, context.nb_shape_values_);
 
   // Set shape tensor address to buffer that contains max allowed value so
   // later shape inference will return max output shape / size for
   // pre-allocation.
   if (!context.context_->setInputTensorAddress(
-          input_name.c_str(), context.max_shapes_[io_index])) {
+          input_name.c_str(), context.max_shapes_[io_index].GetData())) {
     return TRITONSERVER_ErrorNew(
         TRITONSERVER_ERROR_INTERNAL,
         (std::string("trt failed to set the input shape binding for '") +
diff --git a/src/instance_state.h b/src/instance_state.h
index 1d433d3..1564242 100644
--- a/src/instance_state.h
+++ b/src/instance_state.h
@@ -35,6 +35,7 @@
 #include "io_binding_info.h"
 #include "model_state.h"
 #include "semaphore.h"
+#include "shape_tensor.h"
 #include "tensorrt_model_instance.h"
 #include "triton/backend/backend_input_collector.h"
 #include "triton/backend/backend_output_responder.h"
@@ -136,13 +137,13 @@ struct TensorRTContext {
   std::vector<nvinfer1::Dims> opt_dims_{};
 
   // Min shape values per bindings
-  std::vector<const int32_t*> min_shapes_{};
+  std::vector<ShapeTensor> min_shapes_{};
 
   // Max shape values per bindings
-  std::vector<const int32_t*> max_shapes_{};
+  std::vector<ShapeTensor> max_shapes_{};
 
   // Optimized shape values per bindings
-  std::vector<const int32_t*> opt_shapes_{};
+  std::vector<ShapeTensor> opt_shapes_{};
 
   // The number of shape values
   size_t nb_shape_values_{0};
@@ -333,16 +334,16 @@ class ModelInstanceState : public TensorRTModelInstance {
 
   TRITONSERVER_Error* GetRequestShapeValues(
       size_t total_batch_size, TRITONBACKEND_Request* request,
-      std::map<int, std::vector<int32_t>>* request_shape_values);
+      std::map<int, ShapeTensor>* request_shape_values);
 
   TRITONSERVER_Error* GetMostOptimizedProfile(
       size_t total_batch_size, TRITONBACKEND_Request** requests,
       uint32_t request_count,
-      const std::map<int, std::vector<int32_t>>& request_shape_values,
+      const std::map<int, ShapeTensor>& request_shape_values,
       std::map<int, TensorRTContext>::iterator* citr);
 
   TRITONSERVER_Error* EvaluateTensorRTContext(
       std::map<int, TensorRTContext>::iterator& citr, size_t total_batch_size,
       TRITONBACKEND_Request** requests, uint32_t request_count,
-      const std::map<int, std::vector<int32_t>>& request_shape_values,
+      const std::map<int, ShapeTensor>& request_shape_values,
       int64_t* error_distance);
 
   bool SetOutputShapeTensorBuffer(
diff --git a/src/shape_tensor.cc b/src/shape_tensor.cc
new file mode 100644
index 0000000..49930eb
--- /dev/null
+++ b/src/shape_tensor.cc
@@ -0,0 +1,172 @@
+// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "shape_tensor.h"
+
+namespace triton { namespace backend { namespace tensorrt {
+
+TRITONSERVER_Error*
+ShapeTensor::SetDataFromBuffer(
+    const char* data_buffer, size_t data_byte_size,
+    TRITONSERVER_DataType datatype, size_t nb_shape_values,
+    const char* input_name, bool support_batching, size_t total_batch_size)
+{
+  if (data_buffer == nullptr) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        "Null data pointer received for Shape tensor");
+  }
+
+  if (datatype == TRITONSERVER_DataType::TRITONSERVER_TYPE_INT32) {
+    datatype_ = ShapeTensorDataType::INT32;
+  } else if (datatype == TRITONSERVER_DataType::TRITONSERVER_TYPE_INT64) {
+    datatype_ = ShapeTensorDataType::INT64;
+  } else {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        "Unsupported data type received for Shape tensor");
+  }
+
+  nb_shape_values_ = nb_shape_values;
+  if (support_batching) {
+    nb_shape_values_++;  // Account for batch size
+  }
+  const size_t datatype_size = TRITONSERVER_DataTypeByteSize(datatype);
+  size_ = nb_shape_values_ * datatype_size;
+
+  TRITONSERVER_Error* err =
+      ValidateDataByteSize(data_byte_size, input_name, datatype_size);
+  if (err != nullptr) {
+    return err;
+  }
+
+  data_ = std::make_unique<char[]>(size_);
+  if (support_batching) {
+    if (datatype_ == ShapeTensorDataType::INT32) {
+      *reinterpret_cast<int32_t*>(data_.get()) =
+          static_cast<int32_t>(total_batch_size);
+    } else if (datatype_ == ShapeTensorDataType::INT64) {
+      *reinterpret_cast<int64_t*>(data_.get()) =
+          static_cast<int64_t>(total_batch_size);
+    }
+    std::memcpy(
+        data_.get() + datatype_size, data_buffer, size_ - datatype_size);
+  } else {
+    std::memcpy(data_.get(), data_buffer, size_);
+  }
+
+  return nullptr;
+}
+
+TRITONSERVER_Error*
+ShapeTensor::SetDataFromShapeValues(
+    const int32_t* shape_values, TRITONSERVER_DataType datatype,
+    size_t nb_shape_values)
+{
+  nb_shape_values_ = nb_shape_values;
+  const size_t datatype_size = TRITONSERVER_DataTypeByteSize(datatype);
+  size_ = nb_shape_values_ * datatype_size;
+
+  if (datatype == TRITONSERVER_DataType::TRITONSERVER_TYPE_INT32) {
+    datatype_ = ShapeTensorDataType::INT32;
+    data_.reset(new char[size_]);
+    int32_t* data_ptr = reinterpret_cast<int32_t*>(data_.get());
+    std::memcpy(data_ptr, shape_values, size_);
+  } else if (datatype == TRITONSERVER_DataType::TRITONSERVER_TYPE_INT64) {
+    datatype_ = ShapeTensorDataType::INT64;
+    data_.reset(new char[size_]);
+    int64_t* data_ptr = reinterpret_cast<int64_t*>(data_.get());
+    for (size_t i = 0; i < nb_shape_values_; ++i) {
+      data_ptr[i] = static_cast<int64_t>(shape_values[i]);
+    }
+  } else {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        "Unsupported data type received for Shape tensor");
+  }
+
+  return nullptr;
+}
+
+int64_t
+ShapeTensor::GetDistance(
+    const ShapeTensor& other, int64_t total_batch_size) const
+{
+  int64_t distance = 0;
+  if (datatype_ == ShapeTensorDataType::INT32) {
+    const auto* shape_values = reinterpret_cast<const int32_t*>(data_.get());
+    const auto* opt_shape_values =
+        reinterpret_cast<const int32_t*>(other.GetData());
+    distance += std::abs(*opt_shape_values - total_batch_size);
+    for (size_t idx = 1; idx < other.GetNbShapeValues(); idx++) {
+      distance += std::abs(*(opt_shape_values + idx) - shape_values[idx - 1]);
+    }
+  } else {
+    const auto* shape_values = reinterpret_cast<const int64_t*>(data_.get());
+    const auto* opt_shape_values =
+        reinterpret_cast<const int64_t*>(other.GetData());
+    distance += std::abs(*opt_shape_values - total_batch_size);
+    for (size_t idx = 1; idx < other.GetNbShapeValues(); idx++) {
+      distance += std::abs(*(opt_shape_values + idx) - shape_values[idx - 1]);
+    }
+  }
+  return distance;
+}
+
+const char*
+ShapeTensor::GetDataTypeString() const
+{
+  switch (datatype_) {
+    case ShapeTensorDataType::INT32:
+      return "INT32";
+    case ShapeTensorDataType::INT64:
+      return "INT64";
+    default:
+      break;
+  }
+  return nullptr;
+}
+
+TRITONSERVER_Error*
+ShapeTensor::ValidateDataByteSize(
+    size_t data_byte_size, const char* input_name, size_t datatype_size) const
+{
+  if ((data_byte_size != (size_ - datatype_size)) &&
+      (data_byte_size != size_)) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        (std::string("shape tensor for input '") + input_name +
+         "' expected byte size is " + std::to_string(size_ - datatype_size) +
+         " [ or " + std::to_string(size_) +
+         " if input includes batch shape value], got " +
+         std::to_string(data_byte_size))
+            .c_str());
+  }
+  return nullptr;
+}
+
+}}}  // namespace triton::backend::tensorrt
diff --git a/src/shape_tensor.h b/src/shape_tensor.h
new file mode 100644
index 0000000..1d4fc1a
--- /dev/null
+++ b/src/shape_tensor.h
@@ -0,0 +1,77 @@
+// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#pragma once
+
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+#include <string>
+
+#include "triton/core/tritonserver.h"
+
+namespace triton { namespace backend { namespace tensorrt {
+
+enum class ShapeTensorDataType { INT32, INT64 };
+
+class ShapeTensor {
+ public:
+  ShapeTensor()
+      : size_(0), nb_shape_values_(0), datatype_(ShapeTensorDataType::INT32)
+  {
+  }
+
+  TRITONSERVER_Error* SetDataFromBuffer(
+      const char* data_buffer, size_t data_byte_size,
+      TRITONSERVER_DataType datatype, size_t nb_shape_values,
+      const char* input_name, bool support_batching, size_t total_batch_size);
+
+  TRITONSERVER_Error* SetDataFromShapeValues(
+      const int32_t* shape_values, TRITONSERVER_DataType datatype,
+      size_t nb_shape_values);
+
+  int64_t GetDistance(const ShapeTensor& other, int64_t total_batch_size) const;
+
+  const char* GetDataTypeString() const;
+
+  size_t GetSize() const { return size_; }
+  size_t GetNbShapeValues() const { return nb_shape_values_; }
+  ShapeTensorDataType GetDataType() const { return datatype_; }
+  const void* GetData() const
+  {
+    return static_cast<const void*>(data_.get());
+  }
+
+ private:
+  size_t size_;
+  size_t nb_shape_values_;
+  ShapeTensorDataType datatype_;
+  std::unique_ptr<char[]> data_;
+
+  TRITONSERVER_Error* ValidateDataByteSize(
+      size_t data_byte_size, const char* input_name,
+      size_t datatype_size) const;
+};
+
+}}}  // namespace triton::backend::tensorrt
diff --git a/src/tensorrt_utils.cc b/src/tensorrt_utils.cc
index 31916ee..2a00a83 100644
--- a/src/tensorrt_utils.cc
+++ b/src/tensorrt_utils.cc
@@ -367,34 +367,40 @@ ValidateControlDimsDynamic(
 
 TRITONSERVER_Error*
 ValidateShapeValues(
-    const std::vector<int32_t>& request_shape_values,
-    const int32_t* min_shape_values, const int32_t* max_shape_values,
-    size_t nb_shape_values, const bool support_batching)
+    const ShapeTensor& request_shape_values,
+    const ShapeTensor& min_shape_values, const ShapeTensor& max_shape_values,
+    size_t nb_shape_values)
 {
-  if (request_shape_values.size() != nb_shape_values) {
+  if (request_shape_values.GetNbShapeValues() != nb_shape_values) {
     return TRITONSERVER_ErrorNew(
         TRITONSERVER_ERROR_INVALID_ARG,
         (std::string(
              "mismatch between the number of shape values. Expecting ") +
          std::to_string(nb_shape_values) + ". Got " +
-         std::to_string(request_shape_values.size()))
+         std::to_string(request_shape_values.GetNbShapeValues()))
             .c_str());
   }
 
-  for (size_t i = 0; i < nb_shape_values; i++) {
-    if (request_shape_values[i] < *(min_shape_values + i) ||
-        request_shape_values[i] > *(max_shape_values + i)) {
-      return TRITONSERVER_ErrorNew(
-          TRITONSERVER_ERROR_INVALID_ARG,
-          (std::string("The shape value at index ") + std::to_string(i) +
-           " is expected to be in range from " +
-           std::to_string(*(min_shape_values + i)) + " to " +
-           std::to_string(*(max_shape_values + i)) +
-           ", Got: " + std::to_string(request_shape_values[i]))
-              .c_str());
-    }
+  if (request_shape_values.GetDataType() != min_shape_values.GetDataType() ||
+      request_shape_values.GetDataType() != max_shape_values.GetDataType()) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        (std::string(
+             "mismatch between the datatypes of shape values. Expecting ") +
+         std::string(min_shape_values.GetDataTypeString()) + ". Got " +
+         std::string(request_shape_values.GetDataTypeString()))
+            .c_str());
+  }
+
+  if (request_shape_values.GetDataType() == ShapeTensorDataType::INT32) {
+    return CheckShapeTensorInRange<int32_t>(
+        request_shape_values, min_shape_values, max_shape_values,
+        nb_shape_values);
+  } else {
+    return CheckShapeTensorInRange<int64_t>(
+        request_shape_values, min_shape_values, max_shape_values,
+        nb_shape_values);
+  }
-  return nullptr;
 }
 
 void
@@ -517,4 +523,31 @@ IsInput(nvinfer1::ICudaEngine* engine, const std::string& tensor_name)
          nvinfer1::TensorIOMode::kINPUT;
 }
 
+bool
+ValidateBatchSize(
+    const size_t total_batch_size, const ShapeTensor& min_shape_values,
+    const ShapeTensor& max_shape_values)
+{
+  ShapeTensorDataType datatype = min_shape_values.GetDataType();
+
+  if (datatype == ShapeTensorDataType::INT32) {
+    const int32_t* min_values =
+        static_cast<const int32_t*>(min_shape_values.GetData());
+    const int32_t* max_values =
+        static_cast<const int32_t*>(max_shape_values.GetData());
+    return (
+        (int32_t)total_batch_size >= *min_values &&
+        (int32_t)total_batch_size <= *max_values);
+  } else if (datatype == ShapeTensorDataType::INT64) {
+    const int64_t* min_values =
+        static_cast<const int64_t*>(min_shape_values.GetData());
+    const int64_t* max_values =
+        static_cast<const int64_t*>(max_shape_values.GetData());
+    return (
+        (int64_t)total_batch_size >= *min_values &&
+        (int64_t)total_batch_size <= *max_values);
+  }
+  return false;
+}
+
 }}}  // namespace triton::backend::tensorrt
diff --git a/src/tensorrt_utils.h b/src/tensorrt_utils.h
index 37c9c12..9944f3b 100644
--- a/src/tensorrt_utils.h
+++ b/src/tensorrt_utils.h
@@ -31,6 +31,7 @@
 #include <string>
 #include <vector>
 
+#include "shape_tensor.h"
 #include "triton/backend/backend_common.h"
 #include "triton/core/tritonserver.h"
@@ -80,9 +81,15 @@ TRITONSERVER_Error* ValidateControlDimsDynamic(
     const nvinfer1::Dims& dims, const bool support_batching);
 
 TRITONSERVER_Error* ValidateShapeValues(
-    const std::vector<int32_t>& request_shape_values,
-    const int32_t* min_shape_values, const int32_t* max_shape_values,
-    size_t nb_shape_values, const bool support_batching);
+    const ShapeTensor& request_shape_values,
+    const ShapeTensor& min_shape_values, const ShapeTensor& max_shape_values,
+    size_t nb_shape_values);
+
+template <typename T>
+TRITONSERVER_Error* CheckShapeTensorInRange(
+    const ShapeTensor& request_shape_values,
+    const ShapeTensor& min_shape_values, const ShapeTensor& max_shape_values,
+    size_t nb_shape_values);
 
 void DimsToDimVec(const nvinfer1::Dims& model_dims, std::vector<int64_t>* dims);
 
@@ -106,6 +113,10 @@ TRITONSERVER_Error* SupportsIntegratedZeroCopy(
 
 bool IsInput(nvinfer1::ICudaEngine* engine, const std::string& tensor_name);
 
+bool ValidateBatchSize(
+    const size_t total_batch_size, const ShapeTensor& min_shape_values,
+    const ShapeTensor& max_shape_values);
+
 //
 // Templates
 //
@@ -146,4 +157,29 @@ ValidateDimension(
   return nullptr;
 }
 
+template <typename T>
+TRITONSERVER_Error*
+CheckShapeTensorInRange(
+    const ShapeTensor& request_shape_values,
+    const ShapeTensor& min_shape_values, const ShapeTensor& max_shape_values,
+    size_t nb_shape_values)
+{
+  const T* request_data =
+      reinterpret_cast<const T*>(request_shape_values.GetData());
+  const T* min_data = reinterpret_cast<const T*>(min_shape_values.GetData());
+  const T* max_data = reinterpret_cast<const T*>(max_shape_values.GetData());
+  for (size_t i = 0; i < nb_shape_values; i++) {
+    if (request_data[i] < min_data[i] || request_data[i] > max_data[i]) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INVALID_ARG,
+          (std::string("The shape value at index ") + std::to_string(i) +
+           " is expected to be in range from " +
+           std::to_string(min_data[i]) + " to " +
+           std::to_string(max_data[i]) +
+           ", Got: " + std::to_string(request_data[i]))
+              .c_str());
+    }
+  }
+  return nullptr;
+}
+
 }}}  // namespace triton::backend::tensorrt
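
For reference, below is a minimal usage sketch of the ShapeTensor buffer-parsing
path introduced by this patch. It is illustrative only and not part of the
patch: the input name "input_shape", the shape values, and the batch size are
invented, and the snippet simply assumes it is compiled against the patched
backend sources.

// Hypothetical standalone check of ShapeTensor::SetDataFromBuffer with INT64
// data and batching enabled; names and values are illustrative only.
#include <cassert>
#include <cstdint>

#include "shape_tensor.h"

int
main()
{
  using triton::backend::tensorrt::ShapeTensor;

  // An INT64 request buffer carrying two shape values and no batch value.
  const int64_t request_values[2] = {4, 8};

  ShapeTensor shape;
  TRITONSERVER_Error* err = shape.SetDataFromBuffer(
      reinterpret_cast<const char*>(request_values), sizeof(request_values),
      TRITONSERVER_TYPE_INT64, 2 /* nb_shape_values */,
      "input_shape" /* hypothetical input name */,
      true /* support_batching */, 3 /* total_batch_size */);
  assert(err == nullptr);

  // With batching enabled the batch size is prepended, so three INT64
  // values are stored: [3, 4, 8].
  assert(shape.GetNbShapeValues() == 3);
  assert(shape.GetSize() == 3 * sizeof(int64_t));
  const auto* values = static_cast<const int64_t*>(shape.GetData());
  assert(values[0] == 3 && values[1] == 4 && values[2] == 8);
  return 0;
}

The same call with TRITONSERVER_TYPE_INT32 data stores int32_t values instead,
which is what lets request values and the profile min/max/opt shapes be
compared uniformly through CheckShapeTensorInRange<T> and GetDistance.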