external stream won't have to sync explicitly
chilo-ms committed Jan 29, 2024
1 parent f7d3a5a commit d29ef3d
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1684,8 +1684,13 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
     }
   }
 
-  if (cuda_graph_enable_) {
-    // cudaStreamSynchronize() is not allowed in cuda graph capture
+  // cuda graph:
+  // cudaStreamSynchronize() is not allowed in cuda graph capture.
+  //
+  // external stream:
+  // If the user provides an "external" cuda stream, only this cuda stream will be used even if multiple threads are running InferenceSession.Run() concurrently.
+  // So, no need to synchronize different streams after enqueueV3.
+  if (cuda_graph_enable_ || external_stream_) {
     sync_stream_after_enqueue_ = false;
   }

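For context on the new external_stream_ branch: a minimal sketch, not part of this commit, of how a caller hands the TRT EP a user-owned CUDA stream. It assumes the legacy OrtTensorRTProviderOptions struct and a hypothetical model path.

#include <cuda_runtime.h>
#include <onnxruntime_cxx_api.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Mark the stream as user-owned; the EP then records external_stream_ and,
  // after this commit, skips the cudaStreamSynchronize() following enqueueV3().
  OrtTensorRTProviderOptions trt_options{};
  trt_options.device_id = 0;
  trt_options.has_user_compute_stream = 1;
  trt_options.user_compute_stream = stream;

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "trt_external_stream");
  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider_TensorRT(trt_options);
  Ort::Session session(env, "model.onnx", session_options);  // hypothetical model path

  // ... session.Run(...) from one or more threads; all TRT work lands on `stream` ...

  cudaStreamSynchronize(stream);  // the caller owns synchronization now
  cudaStreamDestroy(stream);
  return 0;
}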
@@ -3557,19 +3562,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
 
   /*
    * Given that InferenceSession::Run() is guaranteed to be thread-safe, meaning multiple threads can call this function concurrently,
-   * TRT EP needs to carefully take care of concurrency here; if not, at least two concurrency issues might happen:
+   * TRT EP needs to carefully take care of concurrency here; if not, the following concurrency issue might happen:
    *
-   * (1) It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream.
+   * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream.
    * In the design of TRT EP (which does not apply a per-thread context implementation), if multiple threads are calling InferenceSession::Run() concurrently,
    * the trt execution context instance is shared by all the threads and each thread acquires a different stream from ORT.
    * So TRT EP will end up having one trt execution context using multiple streams, which is not suggested.
    * But since the whole compute_func() is protected by the lock, and cudaStreamSynchronize() is enforced here, one trt execution context per stream
    * is guaranteed.
    *
-   * (2) TRT enqueueV3() is async and the stream it uses is managed by ORT SessionState::AcquireDeviceStreamCollection() and DeviceStreamCollection.
-   * So if TRT EP doesn't wait here for the stream to finish all the operations and instead returns right away, the managed stream might still be waiting for
-   * enqueueV3() to be executed, and at the same time the stream might be re-used by another thread which performs InferenceSession::Run() concurrently.
-   *
    * Therefore, TRT EP needs to call cudaStreamSynchronize(), i.e. wait until the stream has completed all operations, to prevent the concurrency issue mentioned above.
    * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture.
    */
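
A simplified sketch of the pattern this comment describes, loosely modeled on the tail of compute_func(); the member names (trt_state->tensorrt_mu_ptr, sync_stream_after_enqueue_) and macros follow the EP's conventions, but this is not the verbatim implementation.

// Serialize access to the shared TRT execution context across threads.
std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));

// enqueueV3() only schedules work on `stream`; it returns before the GPU finishes.
nvinfer1::IExecutionContext& trt_context = *(trt_state->context->get());
if (!trt_context.enqueueV3(stream)) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP execution context enqueue failed.");
}

// Wait for the ORT-managed stream to drain before it can be recycled for another
// Run() call; skipped when cuda graph capture forbids the sync, or when a
// user-owned external stream makes it unnecessary.
if (sync_stream_after_enqueue_) {
  CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
}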
@@ -3854,19 +3855,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
 
   /*
    * Given that InferenceSession::Run() is guaranteed to be thread-safe, meaning multiple threads can call this function concurrently,
-   * TRT EP needs to carefully take care of concurrency here; if not, at least two concurrency issues might happen:
+   * TRT EP needs to carefully take care of concurrency here; if not, the following concurrency issue might happen:
    *
-   * (1) It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream.
+   * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream.
    * In the design of TRT EP (which does not apply a per-thread context implementation), if multiple threads are calling InferenceSession::Run() concurrently,
    * the trt execution context instance is shared by all the threads and each thread acquires a different stream from ORT.
    * So TRT EP will end up having one trt execution context using multiple streams, which is not suggested.
    * But since the whole compute_func() is protected by the lock, and cudaStreamSynchronize() is enforced here, one trt execution context per stream
    * is guaranteed.
    *
-   * (2) TRT enqueueV3() is async and the stream it uses is managed by ORT SessionState::AcquireDeviceStreamCollection() and DeviceStreamCollection.
-   * So if TRT EP doesn't wait here for the stream to finish all the operations and instead returns right away, the managed stream might still be waiting for
-   * enqueueV3() to be executed, and at the same time the stream might be re-used by another thread which performs InferenceSession::Run() concurrently.
-   *
    * Therefore, TRT EP needs to call cudaStreamSynchronize(), i.e. wait until the stream has completed all operations, to prevent the concurrency issue mentioned above.
    * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture.
    */
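
To make the scenario both comment blocks guard against concrete: a hedged sketch of several host threads sharing one session (and therefore one TRT execution context) while ORT hands each Run() a device stream. All names below are illustrative.

#include <onnxruntime_cxx_api.h>
#include <array>
#include <thread>

// Each thread calls Run() on the shared session. ORT may give each invocation a
// different device stream; the post-enqueueV3 sync (when enabled) keeps a stream
// from being recycled while its enqueued work is still in flight.
void run_from_many_threads(Ort::Session& session,
                           const char* input_name, Ort::Value& input,
                           const char* output_name) {
  std::array<std::thread, 4> workers;
  for (auto& worker : workers) {
    worker = std::thread([&] {
      session.Run(Ort::RunOptions{nullptr},
                  &input_name, &input, 1,
                  &output_name, 1);
    });
  }
  for (auto& worker : workers) worker.join();
}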
