[TensorRT EP] Fix InferenceSession::Run() not thread-safe issue #19301

Merged: 10 commits, Jan 30, 2024
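For context, this change removes the conditional sync_stream_after_enqueue heuristic and always synchronizes the stream after enqueueV3() (except during CUDA graph capture), so that concurrent InferenceSession::Run() calls no longer leave work pending on a stream that another thread may reuse. Below is a minimal sketch of the calling pattern the fix targets; the model path, tensor names, and shapes are hypothetical, assuming the ONNX Runtime C++ API with the TensorRT EP appended.

```cpp
#include <onnxruntime_cxx_api.h>
#include <array>
#include <thread>
#include <vector>

// Sketch: several threads share one Ort::Session and call Run() concurrently.
// With this PR, the TRT EP synchronizes the acquired stream after enqueueV3(),
// so the shared TRT execution context is not used on multiple in-flight streams.
int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "trt_concurrent_run");
  Ort::SessionOptions so;

  OrtTensorRTProviderOptions trt_options{};  // default TensorRT EP options
  so.AppendExecutionProvider_TensorRT(trt_options);

  Ort::Session session(env, "model.onnx", so);  // hypothetical model path

  auto run_once = [&session]() {
    // Hypothetical single float input "input" of shape {1, 3} and output "output".
    std::array<float, 3> input_data{1.f, 2.f, 3.f};
    std::array<int64_t, 2> shape{1, 3};
    Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value input = Ort::Value::CreateTensor<float>(mem_info, input_data.data(), input_data.size(),
                                                       shape.data(), shape.size());
    const char* input_names[] = {"input"};
    const char* output_names[] = {"output"};
    // Each concurrent Run() acquires its own ORT device stream.
    session.Run(Ort::RunOptions{nullptr}, input_names, &input, 1, output_names, 1);
  };

  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) threads.emplace_back(run_once);
  for (auto& t : threads) t.join();
  return 0;
}
```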
46 changes: 39 additions & 7 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2529,7 +2529,6 @@
} else if (number_of_trt_nodes == number_of_ort_nodes) {
LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
} else {
sync_stream_after_enqueue_ = true;
LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
}

@@ -3131,7 +3130,7 @@
*p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(),
&parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name],
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,

cpplint warning (GitHub Actions) on line 3133 of onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc: Lines should be <= 120 characters long [whitespace/line_length] [2]
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
@@ -3159,7 +3158,6 @@
const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
auto fused_node_name = trt_state->fused_node_name;
auto& shape_ranges = trt_state->input_shape_ranges;
auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
@@ -3552,7 +3550,25 @@
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}

if (sync_stream_after_enqueue || dds_output_set.size() > 0) {
/*
* Given that InferenceSession::Run() is guaranteed to be thread-safe, meaning multiple threads can call it concurrently,
* TRT EP needs to handle concurrency carefully here; otherwise, at least two concurrency issues can arise:
*
* (1) TensorRT suggests using one TRT execution context per stream when performing inference concurrently in multiple streams.
* In the current TRT EP design (no per-thread context implementation), when multiple threads call InferenceSession::Run() concurrently,
* the TRT execution context instance is shared by all threads while each thread acquires a different stream from ORT.
* TRT EP would therefore end up using one TRT execution context with multiple streams, which is not recommended.
* However, since the whole compute_func() is protected by the lock and cudaStreamSynchronize() is enforced here,
* one TRT execution context per stream is guaranteed.
*
* (2) TRT enqueueV3() is asynchronous, and the stream it uses is managed by ORT SessionState::AcquireDeviceStreamCollection() and DeviceStreamCollection.
* If TRT EP returned right away instead of waiting for the stream to finish all its operations, the managed stream might still have
* enqueueV3() work pending while being reused by another thread that calls InferenceSession::Run() concurrently.
*
* Therefore, TRT EP calls cudaStreamSynchronize(), i.e. waits until the stream has completed all operations, to prevent the concurrency issues mentioned above.
* However, if CUDA graph is enabled, TRT EP won't call cudaStreamSynchronize() since that is not allowed during graph capture.
*/
if (sync_stream_after_enqueue_ && !cuda_graph_enable_) {
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
}

@@ -3696,7 +3712,6 @@
&contexts_[context->node_name],
input_info_[context->node_name],
output_info_[context->node_name],
sync_stream_after_enqueue_,
context_memory_sharing_enable_,
&max_ctx_mem_size_,
&tensorrt_mu_};
@@ -3723,7 +3738,6 @@
const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
auto fused_node_name = trt_state->fused_node_name;
bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
auto trt_engine = trt_state->engine->get();
auto trt_context = trt_state->context->get();
@@ -3833,7 +3847,25 @@
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}

if (sync_stream_after_enqueue || dds_output_set.size() > 0) {
/*
* Given that InferenceSession::Run() is guaranteed to be thread-safe, meaning multiple threads can call it concurrently,
* TRT EP needs to handle concurrency carefully here; otherwise, at least two concurrency issues can arise:
*
* (1) TensorRT suggests using one TRT execution context per stream when performing inference concurrently in multiple streams.
* In the current TRT EP design (no per-thread context implementation), when multiple threads call InferenceSession::Run() concurrently,
* the TRT execution context instance is shared by all threads while each thread acquires a different stream from ORT.
* TRT EP would therefore end up using one TRT execution context with multiple streams, which is not recommended.
* However, since the whole compute_func() is protected by the lock and cudaStreamSynchronize() is enforced here,
* one TRT execution context per stream is guaranteed.
*
* (2) TRT enqueueV3() is asynchronous, and the stream it uses is managed by ORT SessionState::AcquireDeviceStreamCollection() and DeviceStreamCollection.
* If TRT EP returned right away instead of waiting for the stream to finish all its operations, the managed stream might still have
* enqueueV3() work pending while being reused by another thread that calls InferenceSession::Run() concurrently.
*
* Therefore, TRT EP calls cudaStreamSynchronize(), i.e. waits until the stream has completed all operations, to prevent the concurrency issues mentioned above.
* However, if CUDA graph is enabled, TRT EP won't call cudaStreamSynchronize() since that is not allowed during graph capture.
*/
if (sync_stream_after_enqueue_ && !cuda_graph_enable_) {
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
}

@@ -149,7 +149,6 @@ struct TensorrtFuncState {
std::vector<std::unordered_map<std::string, size_t>> input_info;
std::vector<std::unordered_map<std::string, size_t>> output_info;
std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> input_shape_ranges;
bool sync_stream_after_enqueue = false;
OrtMutex* tensorrt_mu_ptr = nullptr;
bool fp16_enable = false;
bool int8_enable = false;
@@ -193,7 +192,6 @@ struct TensorrtShortFuncState {
std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
std::vector<std::unordered_map<std::string, size_t>> input_info;
std::vector<std::unordered_map<std::string, size_t>> output_info;
bool sync_stream_after_enqueue = false;
bool context_memory_sharing_enable = false;
size_t* max_context_mem_size_ptr = nullptr;
OrtMutex* tensorrt_mu_ptr = nullptr;
@@ -335,8 +333,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
cudnnHandle_t external_cudnn_handle_ = nullptr;
cublasHandle_t external_cublas_handle_ = nullptr;

// Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3()
mutable bool sync_stream_after_enqueue_ = false;
// Call cudaStreamSynchronize() after TRT enqueueV3()
mutable bool sync_stream_after_enqueue_ = true;

CUDAGraph cuda_graph_;
bool is_graph_captured_ = false;