diff --git a/include/onnxruntime/core/framework/op_kernel_context.h b/include/onnxruntime/core/framework/op_kernel_context.h index ac22d9130983a..fa2621440ce30 100644 --- a/include/onnxruntime/core/framework/op_kernel_context.h +++ b/include/onnxruntime/core/framework/op_kernel_context.h @@ -186,6 +186,10 @@ class OpKernelContext { */ AllocatorPtr GetAllocator(const OrtDevice& device) const; +#if defined(ENABLE_ATEN) || defined(USE_TENSORRT) + Status SetOutputMLValue(int index, const OrtValue& ort_value); +#endif + protected: OpKernelContext(concurrency::ThreadPool* threadpool, const logging::Logger& logger, Stream* stream); @@ -195,10 +199,6 @@ class OpKernelContext { const OrtValue* GetImplicitInputMLValue(int index) const; OrtValue* GetOutputMLValue(int index); -#ifdef ENABLE_ATEN - Status SetOutputMLValue(int index, const OrtValue& ort_value); -#endif - // Creates the OrtValue* based on the shape, if it does not exist virtual OrtValue* OutputMLValue(int index, const TensorShape& shape); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4a63018f870a6..bbdee0c91b885 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4512,6 +4512,15 @@ struct OrtApi { * \since Version 1.17. */ ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); + + /** \brief Used for custom operators, set an output of a kernel + * + * \see ::OrtCustomOp + * + * \since Version 1.17. + */ + ORT_API2_STATUS(KernelContext_SetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, + _In_ const OrtValue* ort_value); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 467eb31ee2c8e..53b9bead3f1e6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2052,6 +2052,7 @@ struct KernelContext { ConstValue GetInput(size_t index) const; UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const; + void SetOutput(size_t index, const OrtValue& ort_value); void* GetGPUComputeStream() const; Logger GetLogger() const; OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 860a27fc73f79..cb2d3925f9bac 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1634,6 +1634,10 @@ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector(ort_value_idx) >= all_values_size_) { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index 1576c16684faa..e7942934ebe30 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -54,7 +54,7 @@ class IExecutionFrame { const OrtValue* GetNodeInputOrOutputMLValue(int index) const; OrtValue* GetMutableNodeInputOrOutputMLValue(int index); -#ifdef ENABLE_ATEN +#if defined(ENABLE_ATEN) || defined(USE_TENSORRT) // Override the index-th output with ort_value Status SetOutputMLValue(int index, const OrtValue& ort_value); #endif diff --git 
a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index 94b6224440ed0..31b6141ab985d 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -186,7 +186,7 @@ AllocatorPtr OpKernelContext::GetAllocator(const OrtDevice& device) const { return execution_frame_->GetAllocator(device); } -#ifdef ENABLE_ATEN +#if defined(ENABLE_ATEN) || defined(USE_TENSORRT) Status OpKernelContext::SetOutputMLValue(int index, const OrtValue& ort_value) { if (index < 0 || index >= OutputCount()) { return Status(common::ONNXRUNTIME, common::FAIL, diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index a1fc67ff60b6f..af13d85bb17ff 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -649,6 +649,72 @@ Status ApplyProfileShapesFromInputTensorValue(std::vectorgetOutputShape(); + OrtValue* out = nullptr; + + switch (output_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, allocator->getBuffer(), allocator->getSize(), + shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, allocator->getBuffer(), allocator->getSize(), + shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, allocator->getBuffer(), allocator->getSize(), + shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, allocator->getBuffer(), allocator->getSize(), + shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, allocator->getBuffer(), allocator->getSize(), + shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, allocator->getBuffer(), allocator->getSize(), + shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, allocator->getBuffer(), allocator->getSize(), + shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, allocator->getBuffer(), allocator->getSize(), + shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); + } + } + ctx.SetOutput(output_index, *out); + return Status::OK(); +} + TensorrtExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) { if (has_user_compute_stream) { CUDA_CALL_THROW(cudaSetDevice(device_id)); @@ -1041,10 +1107,6 @@ 
TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv throw std::runtime_error("Failed to create directory " + global_cache_path_); } } - { - auto lock = GetApiLock(); - runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger())); - } } if (engine_decryption_enable_) { @@ -1733,6 +1795,21 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::vector> TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { + // Construct subgraph capability from node list + std::vector> result; + + // If the model only consists of one single "EPContext" contrib op, it means TRT EP can run the precompiled engine directly without + // having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation. + // So, simply return the ComputeCapability here. + if (CheckPrecompiledEngine(graph)) { + if (IsValidEPContextNode(graph)) { + SubGraph_t supported_node_vector = {{0}, false}; + std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0); + result.push_back(ComputeCapability::Create(std::move(sub_graph))); + } + return result; + } + // Get ModelPath const auto& path_string = graph.ModelPath().ToPathString(); #ifdef _WIN32 @@ -1817,9 +1894,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } } - // Construct subgraph capability from node list - std::vector> result; - // Handle the case where the graph is subgraph of control flow op. // The purpose is to make control flow op as well as its subgraphs run on TRT. // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. 
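For readers new to the API surface touched above: a minimal, illustrative sketch (not part of the patch) of how the new `KernelContext_SetOutput` C API and its `Ort::KernelContext::SetOutput` C++ wrapper can be used to hand an existing device buffer back to ORT as a kernel output, mirroring what the `BindKernelOutput` helper added above does for DDS outputs. The function name, buffer, shape, and element type below are placeholders.

```cpp
#include <vector>
#include "onnxruntime_cxx_api.h"

// Hypothetical helper: wrap a caller-provided CUDA buffer in an OrtValue and bind it
// as output `idx` of the kernel context, following the BindKernelOutput pattern above.
void SetPrecomputedOutput(Ort::KernelContext& ctx, OrtMemoryInfo* mem_info, size_t idx,
                          void* buffer, size_t buffer_bytes, const std::vector<int64_t>& shape) {
  OrtValue* out = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(
      mem_info, buffer, buffer_bytes, shape.data(), shape.size(),
      ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &out));
  ctx.SetOutput(idx, *out);  // new C++ wrapper over KernelContext_SetOutput
}
```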
@@ -1907,736 +1981,964 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, return result; } -common::Status TensorrtExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { - for (auto& fused_node_graph : fused_nodes_and_graphs) { - const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; - const Node& fused_node = fused_node_graph.fused_node; - // Build map from input name to its index in input definitions - std::unordered_map input_map; - const auto& input_defs = fused_node.InputDefs(); - input_map.reserve(input_defs.size()); - for (size_t i = 0, end = input_defs.size(); i < end; ++i) { - input_map[input_defs[i]->Name()] = i; - } - - // Build map from output name to its index in output definitions - std::unordered_map output_map; - const auto& output_defs = fused_node.OutputDefs(); - output_map.reserve(output_defs.size()); - for (size_t i = 0, end = output_defs.size(); i < end; ++i) { - output_map[output_defs[i]->Name()] = i; - } - - // Reconstruct graph proto from fused node's function body - auto model = graph_body_viewer.CreateModel(*GetLogger()); - auto model_proto = model->ToProto(); - graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - std::string string_buf; - model_proto->SerializeToString(string_buf); +Status TensorrtExecutionProvider::CreateNodeComputeFromPrecompiledEngine(const GraphViewer& graph_body_viewer, + const Node& fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + std::vector& node_compute_funcs) { + if (!IsValidEPContextNode(graph_body_viewer)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EPContext node."); + } + auto node = graph_body_viewer.GetNode(0); + auto& attrs = node->GetAttributes(); + + std::unique_ptr trt_engine; + std::unique_ptr trt_context; + std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index + std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index + std::unordered_map output_types; + + // Deserialize engine + // + // ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0 + if (attrs.at(EP_CONTEXT_ATTR_EMBED_MODE).i() == 0) { + std::filesystem::path engine_cache_path{attrs.at(EP_CONTEXT_ATTR_CACHE_CTX).s()}; + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string()); + } + } + + // Build context + // + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { + size_t mem_size = trt_engine->getDeviceMemorySize(); + if (mem_size > max_ctx_mem_size_) { + max_ctx_mem_size_ = mem_size; + } + trt_context = 
std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + if (!trt_context) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not build execution context for fused node: " + fused_node.Name()); + } - if (dump_subgraphs_) { - // Dump TensorRT subgraphs - std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); - model_proto->SerializeToOstream(dump); - } + + // Create input/output to index maps + for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + const auto& iter = input_map.find(name); + if (iter != input_map.end()) { + input_indexes[name] = iter->second; + } + } else { + const auto& iter = output_map.find(name); + if (iter != output_map.end()) { + output_indexes[name] = iter->second; + } + } + } + + // Create output to type map + for (auto node_arg : graph_body_viewer.GetOutputs()) { + auto output_name = node_arg->Name(); + auto& type = node_arg->TypeAsProto()->tensor_type(); + output_types[output_name] = type.elem_type(); + } + + // Save TRT engine, TRT context and input/output info to map + engines_.emplace(fused_node.Name(), std::move(trt_engine)); + contexts_.emplace(fused_node.Name(), std::move(trt_context)); + input_info_[fused_node.Name()].push_back(input_indexes); + output_info_[fused_node.Name()].push_back(output_indexes); + output_info_[fused_node.Name()].push_back(output_types); + + // Create function state + // TODO: remove default capture + NodeComputeInfo compute_info; + compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { + std::unique_ptr p = std::make_unique(); + *p = {context->allocate_func, + context->release_func, + context->allocator_handle, + &engines_[context->node_name], + &contexts_[context->node_name], + input_info_[context->node_name], + output_info_[context->node_name], + sync_stream_after_enqueue_, + dds_output_allocator_map_[context->node_name], + context_memory_sharing_enable_, + &max_ctx_mem_size_, + &tensorrt_mu_}; + *state = p.release(); + return 0; + }; - TensorrtLogger& trt_logger = GetTensorrtLogger(); - auto trt_builder = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); - const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(explicitBatch)); - auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); - auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); - trt_config->setMaxWorkspaceSize(max_workspace_size_); + // Release function state + compute_info.release_state_func = [](FunctionState state) { + delete static_cast(state); + }; - // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow - if (fp16_enable_ && layer_norm_fp32_fallback_) { - for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { - auto layer = trt_network->getLayer(idx); - auto next_layer = trt_network->getLayer(idx + 1); - if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { - LOGS_DEFAULT(VERBOSE) 
<< "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; - layer->setPrecision(nvinfer1::DataType::kFLOAT); - next_layer->setPrecision(nvinfer1::DataType::kFLOAT); - layer->setOutputType(0, nvinfer1::DataType::kFLOAT); - next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT); - } + // Create compute function + compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + Ort::KernelContext ctx(context); + + TensorrtShortFuncState* trt_state = reinterpret_cast(state); + + // The whole compute_function should be considered the critical section. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; + auto& dds_output_allocator_map = trt_state->dds_output_allocator_map; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + int num_inputs = static_cast(input_indexes.size()); + int num_outputs = static_cast(output_indexes.size()); + + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_); + if (alloc_ == nullptr) { + Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); + } + OrtAllocator* alloc = alloc_; + + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // Get input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); } } - int num_inputs = trt_network->getNbInputs(); - int num_outputs = trt_network->getNbOutputs(); - std::unordered_map input_indexes(num_inputs); - std::unordered_map output_indexes(num_outputs); - std::unordered_map output_types(num_outputs); - /* - * Initialize shape range for each dynamic shape input tensor: - * 1) If user explicitly specifies optimization profiles via provider options, TRT EP will create those profiles during EP compile time. - * It won't make adjustment for profile values during EP compute time. - * - * 2) If no explicit optimization profiles provided by user, TRT EP will firstly set min/max/opt shape to [INT_MAX, INT_MIN, INT_MIN]. - * Later in EP compute time, the shape will be adjusted to [min_input_value, max_input_value, max_input_value] based on input tensor value. - * - * - * Once the TRT profiles are created: - * 1) If all the dynamic shape input tensors have associated profiles explicitly provided by user, those profiles will be applied to TRT builder config - * and the engine will be built at EP compile time. 
- * - * 2) As long as one of the dynamic shape input tensors has no explicitly associated profile, TRT EP will create default shape as described above, - * and all the profiles won't be applied and engine won't be built until EP compute time. + * Set input shapes and bind input buffers */ - bool has_dynamic_shape = false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false. - bool has_explicit_profile = false; - bool apply_explicit_profile = false; - int num_profiles = 0; - std::vector trt_profiles; - - // Following c++ map data structure is used to help serialize/deserialize profiles where it saves dynamic shape dimension(s) and min/max/opt values for dynamic shape input tensor. - // - // (1) Single profile case: - // For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b - // has one dynamic shape dimension: dim_1. The data will be: - // { - // tensor_a: { - // dim_0: [[min_shape, max_shape, opt_shape]], - // dim_2: [[min_shape, max_shape, opt_shape]] - // }, - // tensor_b: { - // dim_1: [[min_shape, max_shape, opt_shape]] - // } - // } - // - // (2) Multiple profiles case: - // For example, assume tensor_a has one dynamic shap dimension: dim 0, and tensor_b has one dynamic shape dimension: dim_1, - // and both of the tensors have two profiles. The data will be: - // { - // tensor_a: { - // dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] - // }, - // tensor_b: { - // dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] - // } - // } - ShapeRangesMap input_explicit_shape_ranges; - ShapeRangesMap input_implicit_shape_ranges; - - if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) { - has_explicit_profile = true; - num_profiles = GetNumProfiles(profile_min_shapes_); - for (int i = 0; i < num_profiles; i++) { - trt_profiles.push_back(trt_builder->createOptimizationProfile()); - } - } - - // Iterate all input tensors to check dynamic shape - for (unsigned int i = 0, end = num_inputs; i < end; ++i) { - auto input = trt_network->getInput(i); - const std::string& input_name = input->getName(); - nvinfer1::Dims dims = input->getDimensions(); - int nb_dims = dims.nbDims; - - // Apply explicit optimization profiles provided by user - if (has_explicit_profile) { - apply_explicit_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); - } + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + const auto tensor_type = tensor_info.GetElementType(); - // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on input tensor values at EP compute time - if (!apply_explicit_profile) { - if (input->isShapeTensor()) { - // Shape tensor - std::vector> profile_vector; - std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; - profile_vector.push_back(shape_vector); // only one profile needed - input_implicit_shape_ranges[input_name][0] = 
profile_vector; - has_dynamic_shape = true; - } else { - // Execution tensor + if (trt_engine->isShapeInferenceIO(input_name)) { + // Get the shape value of shape tensor + nvinfer1::Dims dims = trt_engine->getTensorShape(input_name); + int nb_dims = dims.nbDims; + int shape_size = nb_dims == 0 ? 1 : static_cast(tensor_shapes[0]); + std::vector shape_values(shape_size, 1); + + switch (tensor_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto input = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + for (int j = 0; j < shape_size; ++j) { + shape_values[j] = input[j]; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + auto input = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + for (int j = 0; j < shape_size; ++j) { + shape_values[j] = static_cast(input[j]); + } + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported."); + } + } + + // Bind input tensor which is shape tensor + if (!trt_context->setTensorAddress(input_name, &shape_values[0])) { + std::string error_input_name = input_name; + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'")); + } + } else { + // Set shape for input tensor which is execution tensor + nvinfer1::Dims dims = trt_context->getTensorShape(input_name); + int nb_dims = dims.nbDims; for (int j = 0, end = nb_dims; j < end; ++j) { - if (dims.d[j] == -1) { - std::vector> profile_vector; - std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; - profile_vector.push_back(shape_vector); // only one profile needed - input_implicit_shape_ranges[input_name][j] = profile_vector; - has_dynamic_shape = true; + dims.d[j] = static_cast(tensor_shapes[j]); + } + if (!trt_context->setInputShape(input_name, dims)) { + std::string error_input_name = input_name; + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); + } + // Bind input buffers + void* data = nullptr; + switch (tensor_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); + data = scratch_buffers.back().get(); + } else { + data = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); + data = scratch_buffers.back().get(); + } else { + data = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); + data = scratch_buffers.back().get(); + } else { + data = 
const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); + data = scratch_buffers.back().get(); + } else { + data = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); + data = scratch_buffers.back().get(); + } else { + data = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); + data = scratch_buffers.back().get(); + } else { + data = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); + data = scratch_buffers.back().get(); + } else { + SafeInt input_dim_size = 1; + for (int j = 0, end = nb_dims; j < end; ++j) { + if (tensor_shapes[j] == 0) { + input_dim_size = 1; + break; + } else { + input_dim_size *= tensor_shapes[j]; + } + } + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(int32_t))); + data = scratch_buffers.back().get(); + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), input_dim_size); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + // Cast DOUBLE input to FLOAT because TensorRT doesn't fully support DOUBLE + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); + data = scratch_buffers.back().get(); + } else { + SafeInt input_dim_size = 1; + for (int j = 0, end = nb_dims; j < end; ++j) { + if (tensor_shapes[j] == 0) { + input_dim_size = 1; + break; + } else { + input_dim_size *= tensor_shapes[j]; + } + } + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(float))); + data = scratch_buffers.back().get(); + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), input_dim_size); + } + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); } } - } - apply_explicit_profile = false; + trt_context->setTensorAddress(input_name, data); } } - // Set explicit profiles in TRT config if all dynamic shape inputs have associated profiles provided by user - if (has_explicit_profile) { - // TRT EP has a constraint here. - // Users need to provide all the dynamic shape inputs with associated profiles if they want to explicitly specify profiles through provider options.
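The INT64 and DOUBLE input branches above repeat one pattern: compute the element count, allocate a scratch buffer of the narrower type TensorRT supports, and cast on the stream. Below is a condensed sketch of that pattern, reusing the internal helpers the patch already relies on (`SafeInt`, `IAllocator::MakeUniquePtrFromOrtAllocator`, `cuda::Impl_Cast`); the helper name and signature are hypothetical.

```cpp
// Hypothetical helper condensing the INT64 -> INT32 / DOUBLE -> FLOAT input handling above.
// SrcT is the ONNX element type TensorRT cannot consume, DstT the narrower type it can.
template <typename SrcT, typename DstT>
void* CastInputToScratch(const SrcT* src, const std::vector<int64_t>& shape,
                         OrtAllocator* alloc, cudaStream_t stream,
                         std::vector<IAllocatorUniquePtr<void>>& scratch_buffers) {
  SafeInt<int64_t> elem_count = 1;
  for (auto d : shape) {
    if (d == 0) {  // zero-sized dim: fall back to a one-element scratch buffer, as above
      elem_count = 1;
      break;
    }
    elem_count *= d;
  }
  scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, elem_count * sizeof(DstT)));
  void* data = scratch_buffers.back().get();
  cuda::Impl_Cast<SrcT, DstT>(stream, src, reinterpret_cast<DstT*>(data), elem_count);
  return data;
}
```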
- if (has_dynamic_shape) { - std::ostringstream msg; - msg << "User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly set profiles through provider options.\n"; - msg << "Please note that main graph could be partitioned into TRT/CUDA/CPU subgraphs, in this case, user also needs to provide shape profiles for the TRT subgraph's input if it's dynamic shape input.\n"; - msg << "Following input(s) has no associated shape profiles provided: "; - auto begin = input_implicit_shape_ranges.begin(); - auto end = input_implicit_shape_ranges.end(); - auto it = begin; - if (it != end) { - msg << it->first; - ++it; - } - for (; it != end; ++it) { - msg << "," << it->first; - } - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, msg.str()); - } else { - for (auto trt_profile : trt_profiles) { - trt_config->addOptimizationProfile(trt_profile); + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + std::unordered_set dds_output_set; + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + // Get output shape + nvinfer1::Dims dims = trt_context->getTensorShape(output_name); + int nb_dims = dims.nbDims; + bool is_dds_output = false; + std::vector output_shapes(nb_dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + // data-dependent shape + if (dims.d[j] == -1) { + is_dds_output = true; + dds_output_set.emplace(output_name); + break; } + output_shapes[j] = dims.d[j]; } - } - // If no explicit profile is applied and the input has dynamic shape, TRT EP simply creates one profile by default. - // It will later set proper min/max/opt shape values duing EP compute time. 
- else if (!has_explicit_profile && has_dynamic_shape) { - trt_profiles.push_back(trt_builder->createOptimizationProfile()); - } - // Check platform availability for low precision - if (fp16_enable_) { - if (!trt_builder->platformHasFastFp16()) { - fp16_enable_ = false; - LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16"; + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; } - } - if (int8_enable_) { - if (!trt_builder->platformHasFastInt8()) { - int8_enable_ = false; - LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; - } - } - - // Load INT8 calibration table - std::unordered_map dynamic_range_map; - if (int8_enable_ && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); - } - } - - // Set precision flags - std::string trt_node_name_with_precision = fused_node.Name(); - if (fp16_enable_ && int8_enable_) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - trt_node_name_with_precision += "_fp16_int8"; - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled"; - } else if (fp16_enable_) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - trt_node_name_with_precision += "_fp16"; - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; - } else if (int8_enable_) { - trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); - trt_node_name_with_precision += "_int8"; - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; - } - - // Set DLA - if (fp16_enable_ || int8_enable_) { - if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 - int number_of_dla_core = trt_builder->getNbDLACores(); - if (number_of_dla_core == 0) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; - dla_enable_ = false; - } else { - if (dla_core_ >= number_of_dla_core) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". Use DLA core 0 instead."; - dla_core_ = 0; + // If the output tensor has data-dependent shape, TRT EP will provide an IOutputAllocator for enqueueV3 to dynamically allocate memory buffer. + // Once enqueueV3 returns, TRT EP will then bind the output allocation to ORT kernel context output. + // (Please note that we take strategy A mentioned in https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dynamic-shaped-output, + // where we defer allocation until the size is known and don't call IExecutionContext::setTensorAddress) + // + // Otherwise, if the shape of the output tensor is known prior to runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3.
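The comment above describes the deferred-allocation approach (strategy A in the linked TensorRT docs) for data-dependent-shape outputs. For context, here is a minimal sketch of the `nvinfer1::IOutputAllocator` shape that the `OutputAllocator` used below roughly follows (TensorRT 8.5+ interface); buffer freeing and reuse are omitted, and the accessors mirror the ones `BindKernelOutput` calls (`getBuffer`, `getSize`, `getOutputShape`). This is illustrative, not the EP's actual class.

```cpp
// Sketch only: lets TensorRT request the output buffer during enqueueV3, once the
// real byte size of a data-dependent-shape output is known.
class SketchOutputAllocator : public nvinfer1::IOutputAllocator {
 public:
  explicit SketchOutputAllocator(OrtAllocator* alloc) : alloc_(alloc) {}

  void* reallocateOutput(char const* /*tensor_name*/, void* /*current_memory*/,
                         uint64_t size, uint64_t /*alignment*/) noexcept override {
    if (size > allocated_size_) {
      buffer_ = alloc_->Alloc(alloc_, size);  // the real implementation also releases the old buffer
      allocated_size_ = size;
    }
    return buffer_;
  }

  void notifyShape(char const* /*tensor_name*/, nvinfer1::Dims const& dims) noexcept override {
    shape_.assign(dims.d, dims.d + dims.nbDims);  // final output shape reported by TensorRT
  }

  void* getBuffer() const { return buffer_; }
  size_t getSize() const { return allocated_size_; }
  const std::vector<int64_t>& getOutputShape() const { return shape_; }

 private:
  OrtAllocator* alloc_;
  void* buffer_ = nullptr;
  size_t allocated_size_ = 0;
  std::vector<int64_t> shape_;
};
```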
+ if (is_dds_output) { + if (dds_output_allocator_map.find(output_name) == dds_output_allocator_map.end()) { + auto allocator = new OutputAllocator(alloc); + trt_context->setOutputAllocator(output_name, allocator); + dds_output_allocator_map[output_name] = allocator; + } + } else { + output_tensors[i] = ctx.GetOutput(output_index, output_shapes); + auto& output_tensor = output_tensors[i]; + switch (output_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); + buffers[output_name] = scratch_buffers.back().get(); + } else { + buffers[output_name] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); + buffers[output_name] = scratch_buffers.back().get(); + } else { + buffers[output_name] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); + buffers[output_name] = scratch_buffers.back().get(); + } else { + buffers[output_name] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); + buffers[output_name] = scratch_buffers.back().get(); + } else { + buffers[output_name] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); + buffers[output_name] = scratch_buffers.back().get(); + } else { + buffers[output_name] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); + buffers[output_name] = scratch_buffers.back().get(); + } else { + buffers[output_name] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64 + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); + buffers[output_name] = scratch_buffers.back().get(); + output_dim_sizes[i] = 1; + } else { + SafeInt output_dim_size(1); + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dims.d[j] == 0) { + output_dim_size = 1; + break; + } else { + output_dim_size *= dims.d[j]; + } + } + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(int32_t))); + buffers[output_name] = scratch_buffers.back().get(); + output_dim_sizes[i] = output_dim_size; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + // Allocate 
FLOAT CUDA memory for DOUBLE output type because TensorRT doesn't fully support DOUBLE + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); + buffers[output_name] = scratch_buffers.back().get(); + } else { + SafeInt output_dim_size(1); + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dims.d[j] == 0) { + output_dim_size = 1; + break; + } else { + output_dim_size *= dims.d[j]; + } + } + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(float))); + buffers[output_name] = scratch_buffers.back().get(); + output_dim_sizes[i] = output_dim_size; + } + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; - trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); - trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); - trt_config->setDLACore(dla_core_); - trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_); } + trt_context->setTensorAddress(output_name, buffers[output_name]); } } - // enable sparse weights - if (sparsity_enable_) { - trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { + size_t mem_size = trt_engine->getDeviceMemorySize(); + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; + } + trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); } - // enable builder heuristics - if (build_heuristics_enable_) { - trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; - } -#if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8 - // switch optimizaion level - if (builder_optimization_level_ != 3) { - trt_config->setBuilderOptimizationLevel(builder_optimization_level_); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. 
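The `CaptureBegin()` / `CaptureEnd()` / `ReplayGraph()` helpers used in the block that follows come from ORT's shared CUDA graph utility; conceptually they wrap plain CUDA stream capture. A hedged sketch of that underlying sequence (assumes CUDA 11.4+ for `cudaGraphInstantiateWithFlags`), shown only to clarify the mechanism; the EP itself does not issue these raw calls.

```cpp
// Conceptual sketch of stream capture + replay.
cudaGraph_t graph = nullptr;
cudaGraphExec_t graph_exec = nullptr;

CUDA_CALL_THROW(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));  // ~CaptureBegin()
// ... enqueueV3(stream) and any other work to be captured runs here ...
CUDA_CALL_THROW(cudaStreamEndCapture(stream, &graph));                         // ~CaptureEnd()
CUDA_CALL_THROW(cudaGraphInstantiateWithFlags(&graph_exec, graph, 0));
CUDA_CALL_THROW(cudaGraphLaunch(graph_exec, stream));                          // ~ReplayGraph()
```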
+ if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + cuda_graph_.SetStream(stream); + CaptureBegin(); } - // limit auxiliary streams - if (auxiliary_streams_ >= 0) { - trt_config->setMaxAuxStreams(auxiliary_streams_); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_; + // Run TRT inference + if (!trt_context->enqueueV3(stream)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } -#else - if (builder_optimization_level_ != 3) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; - } - if (auxiliary_streams_ >= 0) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; - } -#endif - // limit used tactic sources - if (!tactic_sources_.empty()) { - nvinfer1::TacticSources tactics = trt_config->getTacticSources(); - tactics |= GetTacticSourceFromString(tactic_sources_); - trt_config->setTacticSources(tactics); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; + + if (sync_stream_after_enqueue || dds_output_set.size() > 0) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } - // Build TRT engine (if needed) and load TRT engine if: - // (1) Graph has no dynamic shape input - // (2) All the dynamic shape inputs have associated explicit profiles specified by user - // - // Otherwise engine will be handled at inference time. - std::unique_ptr trt_engine; - std::unique_ptr trt_context; + // Assign TRT output back to ORT output + // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) + // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - cudaDeviceProp prop; - CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_)); - std::string compute_capability = GetComputeCapacity(prop); + size_t output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; + } - if (!has_dynamic_shape) { - const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); - const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine"; - const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile"; - std::string timing_cache_path = ""; - bool engine_update = false; - if (timing_cache_enable_) { - timing_cache_path = GetTimingCachePath(global_cache_path_, prop); + if (dds_output_set.find(output_name) != dds_output_set.end()) { + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, 
status.ErrorMessage()); + } } - { - // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race condition when inferencing with multithreading. - auto lock = GetApiLock(); - // If explicit profile flag is on and engine cache enable flag is on, - // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to rebuild the engine. - if (has_explicit_profile && engine_cache_enable_) { - engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); - if (engine_update) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; - } else { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; - } + auto& output_tensor = output_tensors[i]; + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } + } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); } + } + } - std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); - if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) { - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - if (trt_engine == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from cache: " + engine_cache_path); - } - } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) { - // Decrypt engine - size_t engine_size = 0; - if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not get engine buffer size"); - } - std::unique_ptr engine_buf{new char[engine_size]}; - if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine decryption function decrypt"); - } - // Deserialize engine - trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; - if (trt_engine == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); - } - } else { - // Set INT8 per tensor dynamic range - if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { - trt_config->setInt8Calibrator(nullptr); - if (!SetDynamicRange(*trt_network, dynamic_range_map)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not set INT8 dynamic 
range for fused node: " + fused_node.Name()); - } - } - - // Load timing cache from file. Create a fresh cache if the file doesn't exist - std::unique_ptr timing_cache = nullptr; - if (timing_cache_enable_) { - std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); - timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); - if (timing_cache == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not create timing cache: " + timing_cache_path); - } - trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); - if (detailed_build_log_) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; - } - } - - // Build engine - std::chrono::steady_clock::time_point engine_build_start; - if (detailed_build_log_) { - engine_build_start = std::chrono::steady_clock::now(); - } - trt_engine = std::unique_ptr(trt_builder->buildEngineWithConfig(*trt_network, *trt_config)); - if (trt_engine == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not build engine for fused node: " + fused_node.Name()); - } - if (detailed_build_log_) { - auto engine_build_stop = std::chrono::steady_clock::now(); - LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; - } - if (engine_cache_enable_) { - // Serialize engine profile if it has explicit profiles - if (has_explicit_profile) { - SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; - } - - std::unique_ptr serializedModel(trt_engine->serialize()); - size_t engine_size = serializedModel->size(); - if (engine_decryption_enable_) { - // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. - if (engine_encryption_ != nullptr) { - if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP call to engine encryption library failed"); - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; - } else { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; - } - } else { - std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); - file.write(reinterpret_cast(serializedModel->data()), engine_size); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; - } - } - // serialize and save timing cache - if (timing_cache_enable_) { - auto timing_cache = trt_config->getTimingCache(); - std::unique_ptr timingCacheHostData{timing_cache->serialize()}; - if (timingCacheHostData == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not serialize timing cache: " + timing_cache_path); - } - saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); - if (detailed_build_log_) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; - } - } - } - } - - // Build context - // Note: Creating an execution context from an engine is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - if (context_memory_sharing_enable_) { - size_t mem_size = trt_engine->getDeviceMemorySize(); - if (mem_size > max_ctx_mem_size_) { - max_ctx_mem_size_ = mem_size; - } - trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); + // End CUDA graph capture. + // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture + // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. + // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. + if (cuda_graph_enable_ && !IsGraphCaptured()) { + if (IsGraphCaptureAllowed()) { + CaptureEnd(); + // CUDA work issued to a capturing stream doesn’t actually run on the GPU, + // so run the captured graph here to actually execute the work. 
+ ORT_RETURN_IF_ERROR(ReplayGraph()); } else { - trt_context = std::unique_ptr(trt_engine->createExecutionContext()); - } - if (!trt_context) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not build execution context for fused node: " + fused_node.Name()); + IncrementRegularRunCountBeforeGraphCapture(); } } - // Create input to index map - for (int i = 0; i < num_inputs; ++i) { - auto input = trt_network->getInput(i); - const std::string& input_name = input->getName(); - const auto& iter = input_map.find(input_name); - if (iter != input_map.end()) { - input_indexes[input_name] = iter->second; - } + return Status::OK(); + }; + + return Status::OK(); +} + +Status TensorrtExecutionProvider::CreateNodeComputeFromOrtGraph(const GraphViewer& graph_body_viewer, + const Node& fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + std::vector& node_compute_funcs) { + // Reconstruct graph proto from fused node's function body + auto model = graph_body_viewer.CreateModel(*GetLogger()); + auto model_proto = model->ToProto(); + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + std::string string_buf; + model_proto->SerializeToString(string_buf); + + if (dump_subgraphs_) { + // Dump TensorRT subgraphs + std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); + model_proto->SerializeToOstream(dump); + } + + TensorrtLogger& trt_logger = GetTensorrtLogger(); + auto trt_builder = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(explicitBatch)); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + trt_config->setMaxWorkspaceSize(max_workspace_size_); + + // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow + if (fp16_enable_ && layer_norm_fp32_fallback_) { + for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { + auto layer = trt_network->getLayer(idx); + auto next_layer = trt_network->getLayer(idx + 1); + if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; + layer->setPrecision(nvinfer1::DataType::kFLOAT); + next_layer->setPrecision(nvinfer1::DataType::kFLOAT); + layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + } + } + } + + int num_inputs = trt_network->getNbInputs(); + int num_outputs = trt_network->getNbOutputs(); + std::unordered_map input_indexes(num_inputs); + std::unordered_map output_indexes(num_outputs); + std::unordered_map output_types(num_outputs); + + /* + * Initialize shape range for each dynamic shape input tensor: + * 1) If user explicitly specifies optimization profiles via provider options, TRT EP will create those profiles during EP compile time. + * It won't make adjustment for profile values during EP compute time. 
+ * + * 2) If no explicit optimization profiles provided by user, TRT EP will firstly set min/max/opt shape to [INT_MAX, INT_MIN, INT_MIN]. + * Later in EP compute time, the shape will be adjusted to [min_input_value, max_input_value, max_input_value] based on input tensor value. + * + * + * Once the TRT profiles are created: + * 1) If all the dynamic shape input tensors have associated profiles explicitly provided by user, those profiles will be applied to TRT builder config + * and the engine will be built at EP compile time. + * + * 2) As long as one of the dynamic shape input tensors has no explicitly associated profile, TRT EP will create default shape as described above, + * and all the profiles won't be applied and engine won't be built until EP compute time. + */ + bool has_dynamic_shape = false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false. + bool has_explicit_profile = false; + bool apply_explicit_profile = false; + int num_profiles = 0; + std::vector trt_profiles; + + // Following c++ map data structure is used to help serialize/deserialize profiles where it saves dynamic shape dimension(s) and min/max/opt values for dynamic shape input tensor. + // + // (1) Single profile case: + // For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b + // has one dynamic shape dimension: dim_1. The data will be: + // { + // tensor_a: { + // dim_0: [[min_shape, max_shape, opt_shape]], + // dim_2: [[min_shape, max_shape, opt_shape]] + // }, + // tensor_b: { + // dim_1: [[min_shape, max_shape, opt_shape]] + // } + // } + // + // (2) Multiple profiles case: + // For example, assume tensor_a has one dynamic shape dimension: dim_0, and tensor_b has one dynamic shape dimension: dim_1, + // and both of the tensors have two profiles.
The data will be: + // { + // tensor_a: { + // dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] + // }, + // tensor_b: { + // dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] + // } + // } + ShapeRangesMap input_explicit_shape_ranges; + ShapeRangesMap input_implicit_shape_ranges; + + if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) { + has_explicit_profile = true; + num_profiles = GetNumProfiles(profile_min_shapes_); + for (int i = 0; i < num_profiles; i++) { + trt_profiles.push_back(trt_builder->createOptimizationProfile()); } + } - // Create output to index and type maps - const auto& graph_output = model_proto->graph().output(); - for (int i = 0; i < num_outputs; ++i) { - const std::string& output_name = trt_network->getOutput(i)->getName(); - const auto& iter = output_map.find(output_name); - if (iter != output_map.end()) { - output_indexes[output_name] = iter->second; - } - const auto& tensor_type = graph_output[i].type().tensor_type(); - output_types[output_name] = tensor_type.elem_type(); - } - - // Save TRT engine, other TRT objects and input/output info to map - parsers_.emplace(fused_node.Name(), std::move(trt_parser)); - engines_.emplace(fused_node.Name(), std::move(trt_engine)); - contexts_.emplace(fused_node.Name(), std::move(trt_context)); - builders_.emplace(fused_node.Name(), std::move(trt_builder)); - networks_.emplace(fused_node.Name(), std::move(trt_network)); - input_info_[fused_node.Name()].push_back(input_indexes); - output_info_[fused_node.Name()].push_back(output_indexes); - output_info_[fused_node.Name()].push_back(output_types); - input_shape_ranges_[fused_node.Name()] = input_implicit_shape_ranges; - profiles_.emplace(fused_node.Name(), std::move(trt_profiles)); - - // Create function state - // TODO: remove default capture - NodeComputeInfo compute_info; - compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { - std::unique_ptr p = std::make_unique(); - // translate tactic sources string to nvinfer1::TacticSources - nvinfer1::TacticSources tactics = 0; - if (!tactic_sources_.empty()) { - tactics = GetTacticSourceFromString(tactic_sources_); - } - *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, - &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], - &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, - dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, - runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, - dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, - global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, - builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics}; - *state = p.release(); - return 0; - }; + // Iterate all input tensors to check dynamic shape + for (unsigned int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + nvinfer1::Dims dims = 
input->getDimensions(); + int nb_dims = dims.nbDims; - // Release function state - compute_info.release_state_func = [](FunctionState state) { - delete static_cast(state); - }; + // Apply explicit optimization profiles provided by user + if (has_explicit_profile) { + apply_explicit_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); + } - // Create compute function - compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { - Ort::KernelContext ctx(context); - - TensorrtFuncState* trt_state = reinterpret_cast(state); - - // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, - // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. - // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); - const std::unordered_map& input_indexes = (trt_state->input_info)[0]; - const std::unordered_map& output_indexes = (trt_state->output_info)[0]; - const std::unordered_map& output_types = (trt_state->output_info)[1]; - bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; - auto fused_node_name = trt_state->fused_node_name; - auto& shape_ranges = trt_state->input_shape_ranges; - auto trt_builder = trt_state->builder->get(); - auto trt_engine = trt_state->engine->get(); - auto trt_context = trt_state->context->get(); - auto trt_profiles = trt_state->profiles; - auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; - int num_inputs = static_cast(input_indexes.size()); - int num_outputs = static_cast(output_indexes.size()); - bool engine_update = false; - bool context_update = false; - std::unordered_set input_names; - std::unordered_map> tensor_shape_values; - - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_); - if (alloc_ == nullptr) { - Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); - } - OrtAllocator* alloc = alloc_; - - void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - cudaDeviceProp prop; - CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_)); - std::string compute_capability = GetComputeCapacity(prop); - - // Prepare cache name - const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); - const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine"; - const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile"; - std::string timing_cache_path = ""; - if (timing_cache_enable_) { - timing_cache_path = GetTimingCachePath(global_cache_path_, prop); - } - - // Load serialized engine - if 
(trt_state->engine_cache_enable && trt_engine == nullptr) { - std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); - std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); - if (engine_file && !trt_state->engine_decryption_enable && profile_file) { - // Deserialize profile - shape_ranges = DeserializeProfileV2(profile_file); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - - // Prepare buffer - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr( - trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; - } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { - shape_ranges = DeserializeProfileV2(profile_file); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - // Decrypt engine - size_t engine_size = 0; - if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not get engine buffer size"); - } - std::unique_ptr engine_buf{new char[engine_size]}; - if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine decryption function decrypt"); - } - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); + // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on input tensor values at EP compute time + if (!apply_explicit_profile) { + if (input->isShapeTensor()) { + // Shape tensor + std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][0] = profile_vector; + has_dynamic_shape = true; + } else { + // Execution tensor + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dims.d[j] == -1) { + std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][j] = profile_vector; + has_dynamic_shape = true; } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and 
DeSerialized " + encrypted_engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; } } + apply_explicit_profile = false; + } + } - // Check and update shape ranges for dynamic shape inputs. - for (int i = 0, end = num_inputs; i < end; ++i) { - auto input = trt_state->network->get()->getInput(i); - const std::string& input_name = input->getName(); - input_names.insert(input_name); - - // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. - // TRT EP will help determine the min/max/opt profile values based on current input tensor value. - if (shape_ranges.find(input_name) != shape_ranges.end()) { - auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); - } - } + // Set explicit profiles in TRT config if all dynamic shape inputs have associated profiles provided by user + if (has_explicit_profile) { + // TRT EP has a constraint here. + // Users need to provide all the dynamic shape inputs with associated profiles if they want to explicitly specify profiles through provider options. + if (has_dynamic_shape) { + std::ostringstream msg; + msg << "User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly set profiles through provider options.\n"; + msg << "Please note that main graph could be partitioned into TRT/CUDA/CPU subgraphs, in this case, user also needs to provide shape profiles for the TRT subgraph's input if it's dynamic shape input.\n"; + msg << "Following input(s) has no associated shape profiles provided: "; + auto begin = input_implicit_shape_ranges.begin(); + auto end = input_implicit_shape_ranges.end(); + auto it = begin; + if (it != end) { + msg << it->first; + ++it; + } + for (; it != end; ++it) { + msg << "," << it->first; + } + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, msg.str()); + } else { + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); } + } + } + // If no explicit profile is applied and the input has dynamic shape, TRT EP simply creates one profile by default. + // It will later set proper min/max/opt shape values duing EP compute time. + else if (!has_explicit_profile && has_dynamic_shape) { + trt_profiles.push_back(trt_builder->createOptimizationProfile()); + } - // Regenerate engine - if (engine_update) { - // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. 
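Editor's note: to make the added profile-handling branches above easier to follow, here is the decision they implement, reduced to a tiny pure function. This is a sketch only, not the EP's actual code; recall that has_dynamic_shape is set only for dynamic-shape inputs that did not receive an explicit profile.

// Illustrative reduction of the profile-handling control flow added above.
enum class ProfilePlan {
  kError,           // explicit profiles were given, but some dynamic-shape input has none
  kBuildAtCompile,  // every needed profile is known now: apply them and build the engine here
  kDeferToCompute   // at least one implicit profile remains: build/refine at inference time
};

ProfilePlan DecideProfilePlan(bool has_explicit_profile, bool has_dynamic_shape) {
  if (has_explicit_profile && has_dynamic_shape) return ProfilePlan::kError;
  if (!has_dynamic_shape) return ProfilePlan::kBuildAtCompile;
  return ProfilePlan::kDeferToCompute;
}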
- trt_state->context->reset(); - trt_state->engine->reset(); - auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); - trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); - for (auto trt_profile : trt_profiles) { - trt_config->addOptimizationProfile(trt_profile); - } + // Check platform availability for low precision + if (fp16_enable_) { + if (!trt_builder->platformHasFastFp16()) { + fp16_enable_ = false; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16"; + } + } - // Set INT8 Per Tensor Dynamic range - if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { - trt_config->setInt8Calibrator(nullptr); - if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); - } + if (int8_enable_) { + if (!trt_builder->platformHasFastInt8()) { + int8_enable_ = false; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; + } + } + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if (int8_enable_ && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + } + } + + // Set precision flags + std::string trt_node_name_with_precision = fused_node.Name(); + if (fp16_enable_ && int8_enable_) { + trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); + trt_node_name_with_precision += "_fp16_int8"; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled"; + } else if (fp16_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + trt_node_name_with_precision += "_fp16"; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + } else if (int8_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + trt_node_name_with_precision += "_int8"; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + } + + // Set DLA + if (fp16_enable_ || int8_enable_) { + if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 + int number_of_dla_core = trt_builder->getNbDLACores(); + if (number_of_dla_core == 0) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; + dla_enable_ = false; + } else { + if (dla_core_ >= number_of_dla_core) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". 
Use DLA core 0 instead."; + dla_core_ = 0; } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(dla_core_); + trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_); + } + } + } - // Set precision - if (trt_state->fp16_enable && trt_state->int8_enable) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - } else if (trt_state->fp16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - } else if (trt_state->int8_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); - } + // enable sparse weights + if (sparsity_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } - // Set DLA (DLA can only run with FP16 or INT8) - if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; - trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); - trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); - trt_config->setDLACore(trt_state->dla_core); - } + // enable builder heuristics + if (build_heuristics_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; + } +#if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8 + // switch optimizaion level + if (builder_optimization_level_ != 3) { + trt_config->setBuilderOptimizationLevel(builder_optimization_level_); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } - // enable sparse weights - if (trt_state->sparsity_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; - } + // limit auxiliary streams + if (auxiliary_streams_ >= 0) { + trt_config->setMaxAuxStreams(auxiliary_streams_); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_; + } +#else + if (builder_optimization_level_ != 3) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (auxiliary_streams_ >= 0) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } +#endif + // limit used tactic sources + if (!tactic_sources_.empty()) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= GetTacticSourceFromString(tactic_sources_); + trt_config->setTacticSources(tactics); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; + } + + // Build TRT engine (if needed) and load TRT engine if: + // (1) Graph has no dynamic shape input + // (2) All the dynamic shape inputs have associated explicit profiles specified by user + // + // Otherwise engine will be handled at inference time. 
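Editor's note: the string appends above (for FP16/INT8 and DLA) build the per-node cache key used later for engine/profile caching. A simplified reconstruction follows; it is a sketch only, and the real code additionally verifies that the platform exposes DLA cores before appending the "_dlacore" part. The node name in the usage comment is illustrative.

#include <string>

// Simplified view of how the precision/DLA suffix of trt_node_name_with_precision is assembled.
std::string PrecisionSuffix(bool fp16, bool int8, bool dla, int dla_core) {
  std::string suffix;
  if (fp16 && int8) {
    suffix += "_fp16_int8";
  } else if (fp16) {
    suffix += "_fp16";
  } else if (int8) {
    suffix += "_int8";
  }
  if ((fp16 || int8) && dla) {  // DLA can only run with FP16 or INT8
    suffix += "_dlacore" + std::to_string(dla_core);
  }
  return suffix;
}
// e.g. "TRTKernel_graph_main_0" + PrecisionSuffix(true, false, true, 1)
//      -> "TRTKernel_graph_main_0_fp16_dlacore1"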
+ std::unique_ptr trt_engine; + std::unique_ptr trt_context; + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity + cudaDeviceProp prop; + CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_)); + std::string compute_capability = GetComputeCapacity(prop); + + if (!has_dynamic_shape) { + const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); + const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine"; + const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; + const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile"; + std::string timing_cache_path = ""; + bool engine_update = false; + if (timing_cache_enable_) { + timing_cache_path = GetTimingCachePath(global_cache_path_, prop); + } + { + // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race condition when inferencing with multithreading. + auto lock = GetApiLock(); - // enable builder heuristics - if (trt_state->build_heuristics_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; - } -#if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8 - // switch optimizaion level - if (trt_state->builder_optimization_level != 3) { - trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + // If explicit profile flag is on and engine cache enable flag is on, + // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to rebuild the engine. 
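Editor's note: the cache file names built above follow a fixed scheme that keys the cache on the GPU's SM version. A standalone sketch of that naming is shown below; it assumes the compute capability string is the concatenated SM version (e.g. "86" on an sm_86 GPU), and the cache prefix/node name in the usage comment is illustrative.

#include <string>

struct CachePaths {
  std::string engine;
  std::string encrypted_engine;
  std::string profile;
};

// Mirrors the naming used above: <prefix>_sm<cc>.engine / .profile, with ".encrypted"
// appended for the encrypted engine variant.
CachePaths MakeCachePaths(const std::string& cache_prefix, const std::string& compute_capability) {
  CachePaths p;
  const std::string stem = cache_prefix + "_sm" + compute_capability;
  p.engine = stem + ".engine";
  p.encrypted_engine = p.engine + ".encrypted";
  p.profile = stem + ".profile";
  return p;
}
// e.g. MakeCachePaths("trt_cache/TRTKernel_graph_main_0_fp16", "86").engine
//      -> "trt_cache/TRTKernel_graph_main_0_fp16_sm86.engine"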
+ if (has_explicit_profile && engine_cache_enable_) { + engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); + if (engine_update) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; + } else { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; } + } - // limit auxiliary streams - if (trt_state->auxiliary_streams >= 0) { - trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) { + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from cache: " + engine_cache_path); } -#else - if (trt_state->builder_optimization_level != 3) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) { + // Decrypt engine + size_t engine_size = 0; + if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not get engine buffer size"); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine decryption function decrypt"); } - if (trt_state->auxiliary_streams >= 0) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + // Deserialize engine + trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); } -#endif - // limit used tactic sources - if (trt_state->filter_tactic_sources) { - nvinfer1::TacticSources tactics = trt_config->getTacticSources(); - tactics |= trt_state->tactic_sources; - trt_config->setTacticSources(tactics); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + } else { + // Set INT8 per tensor dynamic range + if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { + trt_config->setInt8Calibrator(nullptr); + if (!SetDynamicRange(*trt_network, dynamic_range_map)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name()); + } } // Load timing cache from file. 
Create a fresh cache if the file doesn't exist std::unique_ptr timing_cache = nullptr; - if (trt_state->timing_cache_enable) { + if (timing_cache_enable_) { std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); if (timing_cache == nullptr) { @@ -2650,37 +2952,34 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine) = std::unique_ptr( - trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config)); - if (detailed_build_log_) { - auto engine_build_stop = std::chrono::steady_clock::now(); - LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; - } + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); } - if (!(*(trt_state->engine))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + trt_engine = std::unique_ptr(trt_builder->buildEngineWithConfig(*trt_network, *trt_config)); + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not build engine for fused node: " + fused_node.Name()); } - trt_engine = trt_state->engine->get(); - if (trt_state->engine_cache_enable) { - // Serialize engine profile - SerializeProfileV2(profile_cache_path, shape_ranges); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; + } + if (engine_cache_enable_) { + // Serialize engine profile if it has explicit profiles + if (has_explicit_profile) { + SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + } - // Serialize engine std::unique_ptr serializedModel(trt_engine->serialize()); size_t engine_size = serializedModel->size(); - if (trt_state->engine_decryption_enable) { + if (engine_decryption_enable_) { // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. 
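Editor's note: the engine decryption path earlier in this hunk calls the decryption callback twice, first with a null buffer to query the size and then again to fill an allocated buffer. A minimal sketch of that convention is below; the function-pointer type is a hypothetical stand-in for the provider's actual callback type.

#include <cstddef>
#include <string>
#include <vector>

using DecryptFunc = bool (*)(const char* path, char* buffer, size_t* size);  // hypothetical signature

// Two-call convention: query the engine buffer size, then decrypt into the buffer.
bool LoadDecryptedEngine(DecryptFunc decrypt, const std::string& path, std::vector<char>& out) {
  size_t size = 0;
  if (!decrypt(path.c_str(), nullptr, &size)) {
    return false;  // could not get engine buffer size
  }
  out.resize(size);
  return decrypt(path.c_str(), out.data(), &size);  // fill the buffer in place
}
// The resulting buffer is then passed to the TRT runtime's deserializeCudaEngine call, as above.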
- if (trt_state->engine_encryption != nullptr) { - if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { + if (engine_encryption_ != nullptr) { + if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine encryption function encrypt"); + "TensorRT EP call to engine encryption library failed"); } LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; } else { @@ -2689,12 +2988,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(serializedModel->data()), engine_size); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; } } - // serialize and save timing cache - if (trt_state->timing_cache_enable) { + if (timing_cache_enable_) { auto timing_cache = trt_config->getTimingCache(); std::unique_ptr timingCacheHostData{timing_cache->serialize()}; if (timingCacheHostData == nullptr) { @@ -2706,399 +3004,796 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorcontext_memory_sharing_enable) { - *(trt_state->context) = std::unique_ptr( - trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); - } else { - *(trt_state->context) = std::unique_ptr( - trt_state->engine->get()->createExecutionContext()); + // Build context + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { + size_t mem_size = trt_engine->getDeviceMemorySize(); + if (mem_size > max_ctx_mem_size_) { + max_ctx_mem_size_ = mem_size; + } + trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + if (!trt_context) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not build execution context for fused node: " + fused_node.Name()); + } + } + + // Create input to index map + for (int i = 0; i < num_inputs; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + const auto& iter = input_map.find(input_name); + if (iter != input_map.end()) { + input_indexes[input_name] = iter->second; + } + } + + // Create output to index and type maps + const auto& graph_output = model_proto->graph().output(); + for (int i = 0; i < num_outputs; ++i) { + const std::string& output_name = trt_network->getOutput(i)->getName(); + const auto& iter = output_map.find(output_name); + if (iter != output_map.end()) { + output_indexes[output_name] = iter->second; + } + const auto& tensor_type = graph_output[i].type().tensor_type(); + output_types[output_name] = tensor_type.elem_type(); + } + + // Save TRT engine, other TRT objects and input/output info to map + parsers_.emplace(fused_node.Name(), std::move(trt_parser)); + engines_.emplace(fused_node.Name(), std::move(trt_engine)); + contexts_.emplace(fused_node.Name(), std::move(trt_context)); + builders_.emplace(fused_node.Name(), std::move(trt_builder)); + networks_.emplace(fused_node.Name(), std::move(trt_network)); + input_info_[fused_node.Name()].push_back(input_indexes); + output_info_[fused_node.Name()].push_back(output_indexes); + 
output_info_[fused_node.Name()].push_back(output_types); + input_shape_ranges_[fused_node.Name()] = input_implicit_shape_ranges; + profiles_.emplace(fused_node.Name(), std::move(trt_profiles)); + + // Create function state + // TODO: remove default capture + NodeComputeInfo compute_info; + compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { + std::unique_ptr p = std::make_unique(); + // translate tactic sources string to nvinfer1::TacticSources + nvinfer1::TacticSources tactics = 0; + if (!tactic_sources_.empty()) { + tactics = GetTacticSourceFromString(tactic_sources_); + } + *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, + &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], + &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], + input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, + runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, + dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, + global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, + builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics}; + *state = p.release(); + return 0; + }; + + // Release function state + compute_info.release_state_func = [](FunctionState state) { + delete static_cast(state); + }; + + // Create compute function + compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + Ort::KernelContext ctx(context); + + TensorrtFuncState* trt_state = reinterpret_cast(state); + + // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, + // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. 
+ // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; + auto fused_node_name = trt_state->fused_node_name; + auto& shape_ranges = trt_state->input_shape_ranges; + auto trt_builder = trt_state->builder->get(); + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto trt_profiles = trt_state->profiles; + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + int num_inputs = static_cast(input_indexes.size()); + int num_outputs = static_cast(output_indexes.size()); + bool engine_update = false; + bool context_update = false; + std::unordered_set input_names; + std::unordered_map> tensor_shape_values; + + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_); + if (alloc_ == nullptr) { + Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); + } + OrtAllocator* alloc = alloc_; + + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity + cudaDeviceProp prop; + CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_)); + std::string compute_capability = GetComputeCapacity(prop); + + // Prepare cache name + const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); + const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine"; + const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; + const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile"; + std::string timing_cache_path = ""; + if (timing_cache_enable_) { + timing_cache_path = GetTimingCachePath(global_cache_path_, prop); + } + + // Load serialized engine + if (trt_state->engine_cache_enable && trt_engine == nullptr) { + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); + if (engine_file && !trt_state->engine_decryption_enable && profile_file) { + // Deserialize profile + shape_ranges = DeserializeProfileV2(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + + // Prepare buffer + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + 
trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } - if (!(*(trt_state->context))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { + shape_ranges = DeserializeProfileV2(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + // Decrypt engine + size_t engine_size = 0; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not get engine buffer size"); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); } - trt_context = trt_state->context->get(); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; } + } - // Get input and output binding names - int total_bindings = trt_engine->getNbBindings(); - std::vector buffers(total_bindings); - std::vector input_binding_names, output_binding_names; - for (int i = 0, end = total_bindings; i < end; ++i) { - if (trt_engine->bindingIsInput(i)) { - input_binding_names.push_back(trt_engine->getBindingName(i)); - } else { - output_binding_names.push_back(trt_engine->getBindingName(i)); + // Check and update shape ranges for dynamic shape inputs. + for (int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_state->network->get()->getInput(i); + const std::string& input_name = input->getName(); + input_names.insert(input_name); + + // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. + // TRT EP will help determine the min/max/opt profile values based on current input tensor value. 
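Editor's note: the comment above refers to the min/max/opt adjustment that happens at compute time. A reduced sketch of that per-dimension rule follows, matching the description near the top of this function (sentinel [INT_MAX, INT_MIN, INT_MIN] collapses to [value, value, value] on first observation, with opt tracking max); the struct and function names are illustrative, not the EP's actual helpers.

#include <cstdint>

struct ShapeRange {
  int64_t min;
  int64_t max;
  int64_t opt;
};

// Returns true if the observed value widened the range, which means the current
// engine/profile no longer covers this input and a rebuild is needed.
bool UpdateRange(ShapeRange& range, int64_t observed) {
  bool widened = false;
  if (observed < range.min) {
    range.min = observed;
    widened = true;
  }
  if (observed > range.max) {
    range.max = observed;
    range.opt = observed;  // opt follows max, per the profile comment above
    widened = true;
  }
  return widened;
}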
+ if (shape_ranges.find(input_name) != shape_ranges.end()) { + auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); } } + } - // Set input shapes and assign input buffers - std::vector> scratch_buffers; - for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { - const std::string& input_name = input_binding_names[i]; - int binding_index = trt_engine->getBindingIndex(input_name.c_str()); - if (binding_index == -1) { - continue; + // Regenerate engine + if (engine_update) { + // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. + trt_state->context->reset(); + trt_state->engine->reset(); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } + + // Set INT8 Per Tensor Dynamic range + if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { + trt_config->setInt8Calibrator(nullptr); + if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); } + } + + // Set precision + if (trt_state->fp16_enable && trt_state->int8_enable) { + trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); + } else if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + } else if (trt_state->int8_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + } + + // Set DLA (DLA can only run with FP16 or INT8) + if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(trt_state->dla_core); + } - size_t input_index = 0; - const auto iter = input_indexes.find(input_name); - if (iter != input_indexes.end()) { - input_index = iter->second; + // enable sparse weights + if (trt_state->sparsity_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } + + // enable builder heuristics + if (trt_state->build_heuristics_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; + } +#if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8 + // switch optimizaion level + if (trt_state->builder_optimization_level != 3) { + trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } + + // limit auxiliary streams + if (trt_state->auxiliary_streams >= 0) { + trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; + } +#else + if (trt_state->builder_optimization_level 
!= 3) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (trt_state->auxiliary_streams >= 0) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } +#endif + // limit used tactic sources + if (trt_state->filter_tactic_sources) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= trt_state->tactic_sources; + trt_config->setTacticSources(tactics); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + } + + // Load timing cache from file. Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (trt_state->timing_cache_enable) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); + if (timing_cache == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not create timing cache: " + timing_cache_path); } - auto input_tensor = ctx.GetInput(input_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shapes = tensor_info.GetShape(); + trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); + if (detailed_build_log_) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + } + } - // Set dynamic shapes - nvinfer1::Dims dimensions = trt_engine->getBindingDimensions(static_cast(binding_index)); - int nb_dims = dimensions.nbDims; - if (input_names.count(input_name) == 1) { - if (trt_engine->isShapeBinding(binding_index)) { - trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); - } else { - for (int j = 0, end = nb_dims; j < end; ++j) { - dimensions.d[j] = static_cast(tensor_shapes[j]); - } - const bool status = trt_context->setBindingDimensions(binding_index, dimensions); - if (!status) { - ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP cannot set the dynamic dimensions of a binding")); + // Build engine + { + auto lock = GetApiLock(); + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); + } + *(trt_state->engine) = std::unique_ptr( + trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config)); + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; + } + } + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + trt_engine = trt_state->engine->get(); + if (trt_state->engine_cache_enable) { + // Serialize engine profile + SerializeProfileV2(profile_cache_path, shape_ranges); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + + // Serialize engine + std::unique_ptr serializedModel(trt_engine->serialize()); + size_t engine_size = serializedModel->size(); + if (trt_state->engine_decryption_enable) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. 
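Editor's note: the serialize-and-persist branch that follows either hands the serialized engine to the encryption callback or writes it straight to disk. A hedged sketch of that split is below; the callback signature is a hypothetical stand-in matching how engine_encryption is invoked, and, as the warning above notes, nothing is written when encryption is enabled but the encrypt entry point is missing.

#include <cstddef>
#include <fstream>
#include <string>

using EncryptFunc = bool (*)(const char* path, char* data, size_t size);  // hypothetical signature

bool PersistSerializedEngine(const std::string& engine_path, const std::string& encrypted_path,
                             char* data, size_t size, bool encryption_enabled, EncryptFunc encrypt) {
  if (encryption_enabled) {
    if (encrypt == nullptr) {
      return false;  // encrypt function not found; no cache is written to disk
    }
    return encrypt(encrypted_path.c_str(), data, size);
  }
  std::ofstream file(engine_path, std::ios::binary | std::ios::out);
  file.write(data, static_cast<std::streamsize>(size));
  return static_cast<bool>(file);
}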
+ if (trt_state->engine_encryption != nullptr) { + if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine encryption function encrypt"); } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + } else { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; } + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serializedModel->data()), engine_size); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; } + } - const auto input_type = tensor_info.GetElementType(); - switch (input_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = const_cast(input_tensor_ptr); - } - break; + // serialize and save timing cache + if (trt_state->timing_cache_enable) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not serialize timing cache: " + timing_cache_path); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log_) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + } + } + context_update = true; + } + + if (context_update) { + if (trt_state->context_memory_sharing_enable) { + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); + } else { + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext()); + } + if (!(*(trt_state->context))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); + } + trt_context = trt_state->context->get(); + } + + // Get input and output binding names + int total_bindings = trt_engine->getNbBindings(); + std::vector buffers(total_bindings); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + if (trt_engine->bindingIsInput(i)) { + input_binding_names.push_back(trt_engine->getBindingName(i)); + } else { + output_binding_names.push_back(trt_engine->getBindingName(i)); + } + } + + // Set input shapes and assign input buffers + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + const std::string& input_name = input_binding_names[i]; + int binding_index = trt_engine->getBindingIndex(input_name.c_str()); + if (binding_index == -1) { + continue; + } + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + + // Set dynamic shapes + nvinfer1::Dims dimensions = trt_engine->getBindingDimensions(static_cast(binding_index)); + int nb_dims = dimensions.nbDims; + if (input_names.count(input_name) 
== 1) { + if (trt_engine->isShapeBinding(binding_index)) { + trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); + } else { + for (int j = 0, end = nb_dims; j < end; ++j) { + dimensions.d[j] = static_cast(tensor_shapes[j]); } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = const_cast(input_tensor_ptr); - } - break; + const bool status = trt_context->setBindingDimensions(binding_index, dimensions); + if (!status) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP cannot set the dynamic dimensions of a binding")); } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = const_cast(input_tensor_ptr); - } - break; + } + } + + const auto input_type = tensor_info.GetElementType(); + switch (input_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = const_cast(input_tensor_ptr); - } - break; + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = const_cast(input_tensor_ptr); - } - break; + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = const_cast(input_tensor_ptr); - } - break; + 
break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - SafeInt input_dim_size = 1; - for (int j = 0, end = nb_dims; j < end; ++j) { - if (tensor_shapes[j] == 0) { - input_dim_size = 1; - break; - } else { - input_dim_size *= tensor_shapes[j]; - } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 + auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + SafeInt input_dim_size = 1; + for (int j = 0, end = nb_dims; j < end; ++j) { + if (tensor_shapes[j] == 0) { + input_dim_size = 1; + break; + } else { + input_dim_size *= tensor_shapes[j]; } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(int32_t))); - buffers[binding_index] = scratch_buffers.back().get(); - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(buffers[binding_index]), input_dim_size); } - break; + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(int32_t))); + buffers[binding_index] = scratch_buffers.back().get(); + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(buffers[binding_index]), input_dim_size); } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // Cast DOUBLE input to FLOAT because TensorRT doesn't fully support INT64 - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - SafeInt input_dim_size = 1; - for (int j = 0, end = nb_dims; j < end; ++j) { - if (tensor_shapes[j] == 0) { - input_dim_size = 1; - break; - } else { - input_dim_size *= tensor_shapes[j]; - } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + // Cast DOUBLE input to FLOAT because TensorRT doesn't fully support INT64 + 
auto input_tensor_ptr = input_tensor.GetTensorData(); + if (input_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + SafeInt input_dim_size = 1; + for (int j = 0, end = nb_dims; j < end; ++j) { + if (tensor_shapes[j] == 0) { + input_dim_size = 1; + break; + } else { + input_dim_size *= tensor_shapes[j]; } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(float))); - buffers[binding_index] = scratch_buffers.back().get(); - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(buffers[binding_index]), input_dim_size); } - break; - } - default: { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP input onnx tensor data type: " + std::to_string(input_type) + " not supported."); + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(float))); + buffers[binding_index] = scratch_buffers.back().get(); + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(buffers[binding_index]), input_dim_size); } + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP input onnx tensor data type: " + std::to_string(input_type) + " not supported."); } } + } - // Set output shapes and assign output buffers - std::vector output_dim_sizes(num_outputs, 1); - using OutputOrtValue = Ort::UnownedValue; - std::vector output_tensors; - output_tensors.reserve(num_outputs); - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - // Set dynamic shapes - const std::string& output_name = output_binding_names[i]; - int binding_index = trt_engine->getBindingIndex(output_name.c_str()); - if (binding_index == -1) { - continue; - } + // Set output shapes and assign output buffers + std::vector output_dim_sizes(num_outputs, 1); + using OutputOrtValue = Ort::UnownedValue; + std::vector output_tensors; + output_tensors.reserve(num_outputs); + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + // Set dynamic shapes + const std::string& output_name = output_binding_names[i]; + int binding_index = trt_engine->getBindingIndex(output_name.c_str()); + if (binding_index == -1) { + continue; + } - size_t output_index = 0; - const auto& index_iter = output_indexes.find(output_name); - if (index_iter != output_indexes.end()) { - output_index = index_iter->second; - } - nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(binding_index)); - int nb_dims = dimensions.nbDims; - std::vector output_shapes(nb_dims); - for (int j = 0, end = nb_dims; j < end; ++j) { - output_shapes[j] = dimensions.d[j]; - } - output_tensors.push_back(ctx.GetOutput(output_index, output_shapes)); + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(binding_index)); + int nb_dims = dimensions.nbDims; + std::vector output_shapes(nb_dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + output_shapes[j] = dimensions.d[j]; + } + output_tensors.push_back(ctx.GetOutput(output_index, output_shapes)); - size_t output_type = 0; - const auto type_iter = output_types.find(output_name); - if (type_iter != output_types.end()) { - output_type = type_iter->second; - } + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); 
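[Editor's note] Both conversion paths above (INT64 to INT32 and DOUBLE to FLOAT) size their temporary device buffer from the input's element count, collapsing to a single element as soon as any dimension is zero so that empty tensors still get a valid binding. A standalone sketch of that computation, assuming plain int64_t arithmetic in place of the SafeInt overflow guard the EP uses:

#include <cstdint>
#include <vector>

// Element count used to size the INT32/FLOAT conversion buffers above.
// Any zero dimension means an empty tensor, which still gets a one-element buffer.
int64_t ConversionBufferElementCount(const std::vector<int64_t>& dims) {
  int64_t count = 1;
  for (int64_t d : dims) {
    if (d == 0) return 1;
    count *= d;  // the EP wraps this multiplication in SafeInt<int64_t> to trap overflow
  }
  return count;
}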
+ if (type_iter != output_types.end()) { + output_type = type_iter->second; + } - auto& output_tensor = output_tensors.back(); - switch (output_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = output_tensor_ptr; - } - break; + auto& output_tensor = output_tensors.back(); + switch (output_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = output_tensor_ptr; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = output_tensor_ptr; - } - break; + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = output_tensor_ptr; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = output_tensor_ptr; - } - break; + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = output_tensor_ptr; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = output_tensor_ptr; - } - break; + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = output_tensor_ptr; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = output_tensor_ptr; - } - 
break; + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = output_tensor_ptr; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - buffers[binding_index] = output_tensor_ptr; - } - break; + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + buffers[binding_index] = output_tensor_ptr; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64 - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[binding_index] = scratch_buffers.back().get(); - output_dim_sizes[i] = 1; - } else { - SafeInt output_dim_size(output_dim_sizes[i]); - for (int j = 0, end = nb_dims; j < end; ++j) { - if (dimensions.d[j] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= dimensions.d[j]; - } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64 + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); + buffers[binding_index] = scratch_buffers.back().get(); + output_dim_sizes[i] = 1; + } else { + SafeInt output_dim_size(output_dim_sizes[i]); + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dimensions.d[j] == 0) { + output_dim_size = 1; + break; + } else { + output_dim_size *= dimensions.d[j]; } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(int32_t))); - buffers[binding_index] = scratch_buffers.back().get(); - output_dim_sizes[i] = output_dim_size; } - break; + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(int32_t))); + buffers[binding_index] = scratch_buffers.back().get(); + output_dim_sizes[i] = output_dim_size; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // Allocate FLOAT CUDA memory for DOUBLE output type because TensorRT doesn't fully support DOUBLE - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[binding_index] = scratch_buffers.back().get(); - } else { - SafeInt output_dim_size(output_dim_sizes[i]); - for (int j = 0, end = nb_dims; j < end; ++j) { - if (dimensions.d[j] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= dimensions.d[j]; - } + break; + } + case 
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + // Allocate FLOAT CUDA memory for DOUBLE output type because TensorRT doesn't fully support DOUBLE + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr == nullptr) { + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); + buffers[binding_index] = scratch_buffers.back().get(); + } else { + SafeInt output_dim_size(output_dim_sizes[i]); + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dimensions.d[j] == 0) { + output_dim_size = 1; + break; + } else { + output_dim_size *= dimensions.d[j]; } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(float))); - buffers[binding_index] = scratch_buffers.back().get(); - output_dim_sizes[i] = output_dim_size; } - break; - } - default: { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(float))); + buffers[binding_index] = scratch_buffers.back().get(); + output_dim_sizes[i] = output_dim_size; } + break; } - } - - // Set execution context memory - if (trt_state->context_memory_sharing_enable) { - size_t mem_size = trt_engine->getDeviceMemorySize(); - if (mem_size > *max_context_mem_size_ptr) { - *max_context_mem_size_ptr = mem_size; + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); } - trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); } + } - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - cuda_graph_.SetStream(stream); - CaptureBegin(); + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { + size_t mem_size = trt_engine->getDeviceMemorySize(); + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; } + trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + } - // Run TRT inference - if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); - } + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. 
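[Editor's note] The CaptureBegin()/CaptureEnd()/ReplayGraph() calls used in this compute function wrap the standard CUDA stream-capture workflow. The sketch below shows that workflow with the raw CUDA runtime API (CUDA 11.4+), purely as an illustration of what the EP's CUDAGraph helper does under the hood; error handling and graph cleanup are elided:

#include <cuda_runtime_api.h>

// Standard stream-capture pattern behind CaptureBegin()/CaptureEnd()/ReplayGraph().
// Work enqueued between Begin/EndCapture is only recorded, not executed, which is why
// the EP replays the instantiated graph once immediately after capturing it.
cudaError_t CaptureAndReplay(cudaStream_t stream) {
  cudaGraph_t graph = nullptr;
  cudaGraphExec_t graph_exec = nullptr;

  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  // ... enqueue the TensorRT execution (e.g. enqueueV2) on `stream` here ...
  cudaStreamEndCapture(stream, &graph);

  cudaGraphInstantiateWithFlags(&graph_exec, graph, 0);
  return cudaGraphLaunch(graph_exec, stream);  // actually run the recorded work
}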
+ if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + cuda_graph_.SetStream(stream); + CaptureBegin(); + } - if (sync_stream_after_enqueue) { - cudaStreamSynchronize(stream); - } + // Run TRT inference + if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); + } + + if (sync_stream_after_enqueue) { + cudaStreamSynchronize(stream); + } - // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - const std::string& output_name = output_binding_names[i]; - size_t binding_index = trt_engine->getBindingIndex(output_name.c_str()); - size_t output_type = 0; - const auto& iter = output_types.find(output_name); - if (iter != output_types.end()) { - output_type = iter->second; + // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + const std::string& output_name = output_binding_names[i]; + size_t binding_index = trt_engine->getBindingIndex(output_name.c_str()); + size_t output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; + } + auto& output_tensor = output_tensors[i]; + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]); } - auto& output_tensor = output_tensors[i]; - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]); - } - } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]); - } + } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]); } } + } - // End CUDA graph capture. - // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture - // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. - // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured()) { - if (IsGraphCaptureAllowed()) { - CaptureEnd(); - // CUDA work issued to a capturing stream doesn’t actually run on the GPU, - // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(ReplayGraph()); - } else { - IncrementRegularRunCountBeforeGraphCapture(); - } + // End CUDA graph capture. 
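[Editor's note] Because the engine only ever sees INT32/FLOAT bindings for INT64/DOUBLE tensors, the post-enqueue loop above copies the device results back into the INT64/DOUBLE output buffers that ORT allocated, using the element counts recorded in output_dim_sizes. A host-side analogue of that widening step, for illustration only (the EP performs it on the GPU via cuda::Impl_Cast so no device-to-host copy is needed):

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Host-side analogue of the INT32 -> INT64 widening done by cuda::Impl_Cast above.
// `count` corresponds to output_dim_sizes[i]; `src` is the INT32 scratch binding,
// `dst` is the INT64 output tensor allocated by ORT.
void WidenInt32ToInt64(const int32_t* src, int64_t* dst, size_t count) {
  std::transform(src, src + count, dst,
                 [](int32_t v) { return static_cast<int64_t>(v); });
}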
+ // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture + // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. + // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. + if (cuda_graph_enable_ && !IsGraphCaptured()) { + if (IsGraphCaptureAllowed()) { + CaptureEnd(); + // CUDA work issued to a capturing stream doesn’t actually run on the GPU, + // so run the captured graph here to actually execute the work. + ORT_RETURN_IF_ERROR(ReplayGraph()); + } else { + IncrementRegularRunCountBeforeGraphCapture(); } + } - return Status::OK(); - }; + return Status::OK(); + }; - node_compute_funcs.push_back(compute_info); + node_compute_funcs.push_back(compute_info); + return Status::OK(); +} + +common::Status TensorrtExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { + for (auto& fused_node_graph : fused_nodes_and_graphs) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + const Node& fused_node = fused_node_graph.fused_node; + // Build map from input name to its index in input definitions + std::unordered_map input_map; + const auto& input_defs = fused_node.InputDefs(); + input_map.reserve(input_defs.size()); + for (size_t i = 0, end = input_defs.size(); i < end; ++i) { + input_map[input_defs[i]->Name()] = i; + } + + // Build map from output name to its index in output definitions + std::unordered_map output_map; + const auto& output_defs = fused_node.OutputDefs(); + output_map.reserve(output_defs.size()); + for (size_t i = 0, end = output_defs.size(); i < end; ++i) { + output_map[output_defs[i]->Name()] = i; + } + + { + if (!runtime_) { + auto lock = GetApiLock(); + runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger())); + } + } + + Status status; + if (CheckPrecompiledEngine(graph_body_viewer)) { + status = CreateNodeComputeFromPrecompiledEngine(graph_body_viewer, fused_node, input_map, output_map, node_compute_funcs); + } else { + status = CreateNodeComputeFromOrtGraph(graph_body_viewer, fused_node, input_map, output_map, node_compute_funcs); + } + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); + } } return Status::OK(); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index cda08715ea009..4d42e9cd1ba3a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -97,6 +97,61 @@ template using unique_pointer = std::unique_ptr; }; // namespace tensorrt_ptr +template +inline T RoundUp(T m, T n) { + return ((m + n - 1) / n) * n; +} + +// +// Class to allocate memory for outputs with data-dependent shapes. The sizes of those are unknown so pre-allocation is +// not possible. +// +class OutputAllocator : public nvinfer1::IOutputAllocator { + public: + OutputAllocator(OrtAllocator* alloc) + : allocator(alloc) { + } + + void* reallocateOutput( + char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept override { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. 
+ size = std::max(size, static_cast(1)); + if (size > allocated_size) { + buffer = IAllocator::MakeUniquePtrFromOrtAllocator(allocator, RoundUp(size, alignment)); + allocated_size = size; + } + return buffer.get(); + } + + void* getBuffer() { + return buffer.get(); + } + + void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override { + output_shapes.reserve(dims.nbDims); + for (int i = 0; i < dims.nbDims; i++) { + output_shapes.push_back(dims.d[i]); + } + } + + std::vector& getOutputShape() { + return output_shapes; + } + + uint64_t getSize() { + return allocated_size; + } + + ~OutputAllocator() override {} + + private: + OrtAllocator* allocator = nullptr; + IAllocatorUniquePtr buffer; + uint64_t allocated_size = 0; + std::vector output_shapes; +}; + using ShapeRangesMap = std::unordered_map>>>; // Information to construct kernel function state. @@ -145,6 +200,21 @@ struct TensorrtFuncState { bool cuda_graph_enable = 0; }; +struct TensorrtShortFuncState { + AllocateFunc test_allocate_func = nullptr; + DestroyFunc test_release_func = nullptr; + AllocatorHandle allocator = nullptr; + std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; + std::vector> input_info; + std::vector> output_info; + bool sync_stream_after_enqueue = false; + std::unordered_map dds_output_allocator_map; + bool context_memory_sharing_enable = false; + size_t* max_context_mem_size_ptr = nullptr; + OrtMutex* tensorrt_mu_ptr = nullptr; +}; + // Holds important information for building valid ORT graph. struct SubGraphContext { std::unordered_set output_args; @@ -153,6 +223,7 @@ struct SubGraphContext { }; using SubGraphContextMap = std::unordered_map>; +using DDSOutputAllocatorMap = std::unordered_map; // Logical device representation. 
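[Editor's note] The OutputAllocator above implements TensorRT's IOutputAllocator callback interface: for outputs with data-dependent shapes, the engine calls reallocateOutput with the real byte size at enqueue time (rounded up to the requested alignment by RoundUp, e.g. RoundUp(13, 8) == 16) and reports the concrete dims through notifyShape. The sketch below shows how such an allocator is typically attached to an execution context and tracked per output name; the wiring and names are illustrative, not the PR's exact code, and it assumes the OutputAllocator class above, NvInfer.h, <string>, <unordered_map>, and a map from output name to OutputAllocator pointer (the value type of DDSOutputAllocatorMap is elided in the diff):

// Illustrative wiring: register one OutputAllocator per data-dependent-shape output so
// TensorRT allocates the buffer and reports the concrete dims at enqueue time instead of
// requiring a pre-sized ORT output buffer.
void RegisterDDSOutput(nvinfer1::IExecutionContext& trt_context,
                       const std::string& output_name,
                       OrtAllocator* ort_allocator,
                       std::unordered_map<std::string, OutputAllocator*>& dds_allocators) {
  if (dds_allocators.find(output_name) == dds_allocators.end()) {
    auto* allocator = new OutputAllocator(ort_allocator);  // kept alive in the map for the engine's lifetime
    trt_context.setOutputAllocator(output_name.c_str(), allocator);
    dds_allocators.emplace(output_name, allocator);
  }
}

Once inference has run, getBuffer()/getOutputShape() on the stored allocator provide the data pointer and shape needed to build the OrtValue that is handed back to ORT.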
class TensorrtExecutionProvider : public IExecutionProvider { @@ -261,6 +332,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map>> profile_opt_shapes_; std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with std::unordered_map> profiles_; + std::unordered_map dds_output_allocator_map_; // For DDS output tensor // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture cudnnHandle_t external_cudnn_handle_ = nullptr; @@ -452,6 +524,17 @@ class TensorrtExecutionProvider : public IExecutionProvider { */ bool IsLocalValue(const Graph& graph, const std::string& name) const; + Status CreateNodeComputeFromPrecompiledEngine(const GraphViewer& graph_body_viewer, + const Node& fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + std::vector& node_compute_funcs); + Status CreateNodeComputeFromOrtGraph(const GraphViewer& graph_body_viewer, + const Node& fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + std::vector& node_compute_funcs); + bool IsGraphCaptureAllowed() const; void CaptureBegin(); void CaptureEnd(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h index 6bbeab7e94ce4..b2beed42d7c9c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "flatbuffers/idl.h" #include "ort_trt_int8_cal_table.fbs.h" @@ -13,6 +14,10 @@ #include "core/common/path_string.h" #include "core/framework/murmurhash3.h" +static const std::string EP_CONTEXT_OP_TYPE = "EPContext"; +static const std::string EP_CONTEXT_ATTR_EMBED_MODE = "embed_mode"; +static const std::string EP_CONTEXT_ATTR_CACHE_CTX = "ep_cache_context"; + namespace fs = std::experimental::filesystem; namespace onnxruntime { @@ -694,4 +699,30 @@ bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_mapOpType() == EP_CONTEXT_OP_TYPE) { + return true; + } + return false; +} + +/* + * The sanity check for EP context contrib op. 
+ */ +bool IsValidEPContextNode(const GraphViewer& graph) { + assert(graph.NumberOfNodes() == 1); + assert(graph.GetNode(0)->OpType() == EP_CONTEXT_OP_TYPE); + auto node = graph.GetNode(0); + auto& attrs = node->GetAttributes(); + if (attrs.count(EP_CONTEXT_ATTR_EMBED_MODE) > 0 && attrs.count(EP_CONTEXT_ATTR_CACHE_CTX) > 0) { + // ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0 + if (attrs.at(EP_CONTEXT_ATTR_EMBED_MODE).i() == 0 && !std::filesystem::exists(attrs.at(EP_CONTEXT_ATTR_CACHE_CTX).s())) { + LOGS_DEFAULT(ERROR) << "Can't find " << attrs.at(EP_CONTEXT_ATTR_CACHE_CTX).s() << " TensorRT engine"; + return false; + } + } + return true; +} } // namespace onnxruntime diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 041250adc3fc0..3189346b44ed9 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -310,6 +310,22 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutput, _Inout_ OrtKernelContext* API_IMPL_END }; +ORT_API_STATUS_IMPL(OrtApis::KernelContext_SetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const OrtValue* ort_value) { + API_IMPL_BEGIN +#if defined(ENABLE_ATEN) || defined(USE_TENSORRT) + auto status = reinterpret_cast(context)->SetOutputMLValue(gsl::narrow_cast(index), *ort_value); + if (status.IsOK()) + return nullptr; + return onnxruntime::ToOrtStatus(status); +#else + ORT_UNUSED_PARAMETER(context); + ORT_UNUSED_PARAMETER(index); + ORT_UNUSED_PARAMETER(ort_value); + return CreateStatus(ORT_FAIL, "TensorRT execution provider is not enabled in this build."); +#endif + API_IMPL_END +}; + ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, _Inout_ size_t* size) { API_IMPL_BEGIN std::string value; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 9f8786b727ac1..d729b2a957a84 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2721,6 +2721,7 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::ShapeInferContext_SetOutputTypeShape, &OrtApis::SetSymbolicDimensions, &OrtApis::ReadOpAttr, + &OrtApis::KernelContext_SetOutput, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
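[Editor's note] The new KernelContext_SetOutput entry point registered at the end of the 1.17 API table above lets a kernel hand ORT a fully formed OrtValue instead of asking ORT to pre-allocate the output, which is what the TensorRT EP needs for data-dependent output shapes. A hedged usage sketch through the C API follows; creating the OrtValue itself (for example by wrapping an existing device buffer with CreateTensorWithDataAsOrtValue) is left to the caller, and the function name here is hypothetical:

#include <onnxruntime_c_api.h>

// Hedged usage sketch: override output 0 of a kernel with a caller-created OrtValue.
// `api` is the OrtApi fetched at API version 17 or later.
OrtStatus* OverrideOutput(const OrtApi& api, OrtKernelContext* context, const OrtValue* value) {
  // Returns an ORT_FAIL status when the build defines neither USE_TENSORRT nor ENABLE_ATEN,
  // matching the #if guard in custom_ops.cc above.
  return api.KernelContext_SetOutput(context, /*index=*/0, value);
}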
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 09c83219ad2c8..2c7f501c4b8b0 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -184,6 +184,7 @@ ORT_API_STATUS_IMPL(KernelContext_GetInputCount, _In_ const OrtKernelContext* co ORT_API_STATUS_IMPL(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); ORT_API_STATUS_IMPL(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); ORT_API_STATUS_IMPL(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out); +ORT_API_STATUS_IMPL(KernelContext_SetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const OrtValue* ort_value); // OrtTypeInfo methods ORT_API_STATUS_IMPL(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len); diff --git a/onnxruntime/test/providers/cpu/nn/dropout_op_test.cc b/onnxruntime/test/providers/cpu/nn/dropout_op_test.cc index 5860d3167ce67..8d7d46316381b 100644 --- a/onnxruntime/test/providers/cpu/nn/dropout_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/dropout_op_test.cc @@ -30,7 +30,9 @@ TEST(Dropout, WithOptionalOutputOpset10) { test.AddInput("X", dims, {1.0f, 2.0f, 3.0f, 5.0f}); test.AddOutput("Y", dims, {1.0f, 2.0f, 3.0f, 5.0f}); test.AddOutput("mask", dims, {false, false, false, false}); - test.Run(); + // The fix in onnx-tensorrt parser for dropout onnx node is not included in TRT 8.6.1 but might be included in later ORT release. + // Simply skip this for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(Dropout, WithOptionalOutputOpset7) {