From e64b07c4c7624c4aa00218be31a452c5cc8409fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Thu, 23 Nov 2023 18:13:20 +0100 Subject: [PATCH] Update CUDA EP docs with NHWC and other new config options (#18547) --- docs/build/eps.md | 5 ++ .../CUDA-ExecutionProvider.md | 82 +++++++++++++++---- 2 files changed, 69 insertions(+), 18 deletions(-) diff --git a/docs/build/eps.md b/docs/build/eps.md index b6138ecab1f2a..e3d43278b7c96 100644 --- a/docs/build/eps.md +++ b/docs/build/eps.md @@ -54,6 +54,11 @@ The onnxruntime code will look for the provider shared libraries in the same loc ### Build Instructions {: .no_toc } +With an additional CMake argument the CUDA EP can be compiled with additional NHWC ops. +This option is not enabled by default due to the small amount of supported NHWC operators. +Over time more operators will be added but for now append `--cmake_extra_defines onnxruntime_USE_CUDA_NHWC_OPS=ON` to below build scripts to compile with NHWC operators. +Another very helpful CMake build option is to build with NVTX support (`onnxruntime_ENABLE_NVTX_PROFILE=ON`) that will enable much easier profiling using [Nsight Systems](https://developer.nvidia.com/nsight-systems) and correlates CUDA kernels with their actual ONNX operator. + #### Windows ``` diff --git a/docs/execution-providers/CUDA-ExecutionProvider.md b/docs/execution-providers/CUDA-ExecutionProvider.md index 4f88aceb19566..f73df1ebf0909 100644 --- a/docs/execution-providers/CUDA-ExecutionProvider.md +++ b/docs/execution-providers/CUDA-ExecutionProvider.md @@ -24,7 +24,7 @@ Pre-built binaries of ONNX Runtime with CUDA EP are published for most language ## Requirements -Please reference table below for official GPU packages dependencies for the ONNX Runtime inferencing package. Note that ONNX Runtime Training is aligned with PyTorch CUDA versions; refer to the Training tab on [onnxruntime.ai](https://onnxruntime.ai/) for supported versions. +Please reference table below for official GPU packages dependencies for the ONNX Runtime inferencing package. Note that ONNX Runtime Training is aligned with PyTorch CUDA versions; refer to the Training tab on [onnxruntime.ai](https://onnxruntime.ai/) for supported versions. Note: Because of CUDA Minor Version Compatibility, Onnx Runtime built with CUDA 11.4 should be compatible with any CUDA 11.x version. Please reference [Nvidia CUDA Minor Version Compatibility](https://docs.nvidia.com/deploy/cuda-compatibility/#minor-version-compatibility). @@ -57,10 +57,28 @@ The device ID. Default value: 0 +### user_compute_stream +Defines the compute stream for the inference to run on. +It implicitly sets the `has_user_compute_stream` option. It cannot be set through `UpdateCUDAProviderOptions`, but rather `UpdateCUDAProviderOptionsWithValue`. +This cannot be used in combination with an external allocator. +This can not be set using the python API. + +### do_copy_in_default_stream +Whether to do copies in the default stream or use separate streams. The recommended setting is true. If false, there are race conditions and possibly better performance. + +Default value: true + +### use_ep_level_unified_stream +Uses the same CUDA stream for all threads of the CUDA EP. This is implicitly enabled by `has_user_compute_stream`, `enable_cuda_graph` or when using an external allocator. + +Default value: false + ### gpu_mem_limit The size limit of the device memory arena in bytes. This size limit is only for the execution provider's arena. The total device memory usage may be higher. s: max value of C++ size_t type (effectively unlimited) +_Note:_ Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) + ### arena_extend_strategy The strategy for extending the device memory arena. @@ -71,6 +89,8 @@ kSameAsRequested (1) | extend by the requested amount Default value: kNextPowerOfTwo +_Note:_ Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) + ### cudnn_conv_algo_search The type of search done for cuDNN convolution algorithms. @@ -82,55 +102,78 @@ DEFAULT (2) | default algorithm using CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PR Default value: EXHAUSTIVE -### do_copy_in_default_stream -Whether to do copies in the default stream or use separate streams. The recommended setting is true. If false, there are race conditions and possibly better performance. - -Default value: true ### cudnn_conv_use_max_workspace Check [tuning performance for convolution heavy models](#convolution-heavy-models) for details on what this flag does. -This flag is only supported from the V2 version of the provider options struct when used using the C API. The V2 provider options struct can be created using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0d29cbf555aa806c050748cf8d2dc172) and updated using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a4710fc51f75a4b9a75bde20acbfa0783). Please take a look at the sample below for an example. +This flag is only supported from the V2 version of the provider options struct when used using the C API.(sample below) Default value: 1, for versions 1.14 and later 0, for previous versions ### cudnn_conv1d_pad_to_nc1d Check [convolution input padding in the CUDA EP](#convolution-input-padding) for details on what this flag does. -This flag is only supported from the V2 version of the provider options struct when used using the C API. The V2 provider options struct can be created using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0d29cbf555aa806c050748cf8d2dc172) and updated using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a4710fc51f75a4b9a75bde20acbfa0783). Please take a look at the sample below for an example. +This flag is only supported from the V2 version of the provider options struct when used using the C API. (sample below) Default value: 0 ### enable_cuda_graph Check [using CUDA Graphs in the CUDA EP](#using-cuda-graphs-preview) for details on what this flag does. -This flag is only supported from the V2 version of the provider options struct when used using the C API. The V2 provider options struct can be created using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0d29cbf555aa806c050748cf8d2dc172) and updated using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a4710fc51f75a4b9a75bde20acbfa0783). +This flag is only supported from the V2 version of the provider options struct when used using the C API. (sample below) Default value: 0 ### enable_skip_layer_norm_strict_mode -Whether to use strict mode in SkipLayerNormalization cuda implementation. The default and recommanded setting is false. If enabled, accuracy improvement and performance drop can be expected. -This flag is only supported from the V2 version of the provider options struct when used using the C API. The V2 provider options struct can be created using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0d29cbf555aa806c050748cf8d2dc172) and updated using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a4710fc51f75a4b9a75bde20acbfa0783). +Whether to use strict mode in SkipLayerNormalization cuda implementation. The default and recommanded setting is false. If enabled, accuracy improvement and performance drop can be expected. +This flag is only supported from the V2 version of the provider options struct when used using the C API. (sample below) + +Default value: 0 + +### gpu_external_[alloc|free|empty_cache] + +gpu_external_* is used to pass external allocators. +Example python usage: +```python +from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator +provider_option_map["gpu_external_alloc"] = str(torch_gpu_allocator.gpu_caching_allocator_raw_alloc_address()) +provider_option_map["gpu_external_free"] = str(torch_gpu_allocator.gpu_caching_allocator_raw_delete_address()) +provider_option_map["gpu_external_empty_cache"] = str(torch_gpu_allocator.gpu_caching_allocator_empty_cache_address()) +``` + +Default value: 0 + +### prefer_nhwc +This option is not available in default builds ! One has to compile ONNX Runtime with `onnxruntime_USE_CUDA_NHWC_OPS=ON`. +If this is enabled the EP prefers NHWC operators over NCHW. Needed transforms will be added to the model. As NVIDIA tensor cores can only work on NHWC layout this can increase performance if the model consists of many supported operators and does not need too many new transpose nodes. Wider operator support is planned in the future. +This flag is only supported from the V2 version of the provider options struct when used using the C API. The V2 provider options struct can be created using [CreateCUDAProviderOptions](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0d29cbf555aa806c050748cf8d2dc172) and updated using [UpdateCUDAProviderOptions](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a4710fc51f75a4b9a75bde20acbfa0783). Default value: 0 ## Performance Tuning -The [I/O Binding feature](../performance/tune-performance/iobinding.md) should be utilized to avoid overhead resulting from copies on inputs and outputs. +The [I/O Binding feature](../performance/tune-performance/iobinding.md) should be utilized to avoid overhead resulting from copies on inputs and outputs. Ideally up and downloads for inputs can be hidden behind the inference. This can be achieved by doing asynchronous copies while running inference. This is demonstrated in this [PR](https://github.com/microsoft/onnxruntime/pull/14088) +```c++ +Ort::RunOptions run_options; +run_options.AddConfigEntry("disable_synchronize_execution_providers", "1"); +session->Run(run_options, io_binding); +``` +By disabling the synchronization on the inference the user has to take care of synchronizing the compute stream after execution. +This feature should only be used with device local memory or an ORT Value allocated in [pinned memory](https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/), otherwise the issued download will be blocking and not behave as desired. ### Convolution-heavy models -ORT leverages CuDNN for convolution operations and the first step in this process is to determine which "optimal" convolution algorithm to use while performing the convolution operation for the given input configuration (input shape, filter shape, etc.) in each `Conv` node. This sub-step involves querying CuDNN for a "workspace" memory size and have this allocated so that CuDNN can use this auxiliary memory while determining the "optimal" convolution algorithm to use. +ORT leverages CuDNN for convolution operations and the first step in this process is to determine which "optimal" convolution algorithm to use while performing the convolution operation for the given input configuration (input shape, filter shape, etc.) in each `Conv` node. This sub-step involves querying CuDNN for a "workspace" memory size and have this allocated so that CuDNN can use this auxiliary memory while determining the "optimal" convolution algorithm to use. The default value of `cudnn_conv_use_max_workspace` is 1 for versions 1.14 or later, and 0 for previous versions. When its value is 0, ORT clamps the workspace size to 32 MB which may lead to a sub-optimal convolution algorithm getting picked by CuDNN. To allow ORT to allocate the maximum possible workspace as determined by CuDNN, a provider option named `cudnn_conv_use_max_workspace` needs to get set (as shown below). Keep in mind that using this flag may increase the peak memory usage by a factor (sometimes a few GBs) but this does help CuDNN pick the best convolution algorithm for the given input. We have found that this is an important flag to use while using an fp16 model as this allows CuDNN to pick tensor core algorithms for the convolution operations (if the hardware supports tensor core operations). This flag may or may not result in performance gains for other data types (`float` and `double`). -* Python +* Python ```python providers = [("CUDAExecutionProvider", {"cudnn_conv_use_max_workspace": '1'})] sess_options = ort.SessionOptions() sess = ort.InferenceSession("my_conv_heavy_fp16_model.onnx", sess_options=sess_options, providers=providers) ``` - + * C/C++ ```c++ @@ -219,7 +262,7 @@ Currently, there are some constraints with regards to using the CUDA Graphs feat * Multi-threaded usage is currently not supported, i.e. `Run()` MAY NOT be invoked on the same `InferenceSession` object from multiple threads while using CUDA Graphs. -NOTE: The very first `Run()` performs a variety of tasks under the hood like making CUDA memory allocations, capturing the CUDA graph for the model, and then performing a graph replay to ensure that the graph runs. Due to this, the latency associated with the first `Run()` is bound to be high. Subsequent `Run()`s only perform graph replays of the graph captured and cached in the first `Run()`. +NOTE: The very first `Run()` performs a variety of tasks under the hood like making CUDA memory allocations, capturing the CUDA graph for the model, and then performing a graph replay to ensure that the graph runs. Due to this, the latency associated with the first `Run()` is bound to be high. Subsequent `Run()`s only perform graph replays of the graph captured and cached in the first `Run()`. * Python @@ -267,10 +310,10 @@ NOTE: The very first `Run()` performs a variety of tasks under the hood like mak void operator()(void* ptr) const { alloc_->Free(ptr); } - + const Ort::Allocator* alloc_; }; - + // Enable cuda graph in cuda provider option. OrtCUDAProviderOptionsV2* cuda_options = nullptr; api.CreateCUDAProviderOptions(&cuda_options); @@ -377,6 +420,10 @@ std::vector values{"0", "2147483648", "kSameAsRequested", "DEFAULT" UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size()); +cudaStream_t cuda_stream; +cudaStreamCreate(&cuda_stream); +// this implicitly sets "has_user_compute_stream" +UpdateCUDAProviderOptionsWithValue(cuda_options, "user_compute_stream", cuda_stream) OrtSessionOptions* session_options = /* ... */; SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options); @@ -419,4 +466,3 @@ cudaProviderOptions.add("cudnn_conv1d_pad_to_nc1d","1"); OrtSession.SessionOptions options = new OrtSession.SessionOptions(); // Must be closed after the session closes options.addCUDA(cudaProviderOptions); ``` -