Update CUDA EP docs with NHWC and other new config options #18547

Merged 3 commits on Nov 23, 2023
5 changes: 5 additions & 0 deletions docs/build/eps.md
@@ -54,6 +54,11 @@ The onnxruntime code will look for the provider shared libraries in the same loc
### Build Instructions
{: .no_toc }

The CUDA EP can be compiled with additional NHWC ops by passing an extra CMake argument. This option is not enabled by default due to the small number of supported NHWC operators. More operators will be added over time; for now, append `--cmake_extra_defines onnxruntime_USE_CUDA_NHWC_OPS=ON` to the build scripts below to compile with NHWC operators.
Another very helpful CMake build option is NVTX support (`onnxruntime_ENABLE_NVTX_PROFILE=ON`), which makes profiling with [Nsight Systems](https://developer.nvidia.com/nsight-systems) much easier and correlates CUDA kernels with their corresponding ONNX operators.
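As a sketch, both defines can be appended to the Linux build script in one invocation (the CUDA/cuDNN paths and the other flags are placeholders; adjust them for your setup):

```
# Illustrative Linux build with NHWC ops and NVTX profiling enabled.
./build.sh --config Release --build_shared_lib --parallel \
  --use_cuda --cuda_home /usr/local/cuda --cudnn_home /usr/local/cuda \
  --cmake_extra_defines onnxruntime_USE_CUDA_NHWC_OPS=ON onnxruntime_ENABLE_NVTX_PROFILE=ON
```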

#### Windows

45 changes: 40 additions & 5 deletions docs/execution-providers/CUDA-ExecutionProvider.md
@@ -57,10 +57,26 @@ The device ID.

Default value: 0

### has_user_compute_stream
Indicates that the user has provided a `user_compute_stream`.
This cannot be used in combination with an external allocator.
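A hedged sketch with the Python API: the stream handle below is a placeholder integer (in real code it would come from e.g. `torch.cuda.Stream().cuda_stream`), and the session line is commented out since it needs a model and a CUDA-enabled build.

```
# Sketch: hand an externally created CUDA stream to the CUDA EP.
stream_handle = 1234  # placeholder for the address of a live cudaStream_t

provider_options = {
    "device_id": "0",
    # Setting user_compute_stream implies has_user_compute_stream=1;
    # do not combine this with an external allocator.
    "user_compute_stream": str(stream_handle),
}
providers = [("CUDAExecutionProvider", provider_options)]
# import onnxruntime as ort
# session = ort.InferenceSession("model.onnx", providers=providers)
```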
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved

### do_copy_in_default_stream
Whether to do copies in the default stream or use separate streams. The recommended setting is true. Setting it to false may yield better performance, but can introduce race conditions.

Default value: true

### use_ep_level_unified_stream
Uses the same CUDA stream for all threads of the CUDA EP. This is implicitly enabled by `has_user_compute_stream`, `enable_cuda_graph` or when using an external allocator.

Default value: false

### gpu_mem_limit
The size limit of the device memory arena in bytes. This size limit is only for the execution provider's arena. The total device memory usage may be higher.
Default value: max value of C++ `size_t` type (effectively unlimited)

_Note:_ Will be overridden by the contents of `default_memory_arena_cfg` (if specified)
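As an illustration (Python API; option values are passed as strings, and the session line is commented since it requires a model and a CUDA-enabled build):

```
# Sketch: cap the CUDA EP memory arena at 2 GiB.
gpu_mem_limit = 2 * 1024 * 1024 * 1024  # bytes

provider_options = {
    "device_id": "0",
    "gpu_mem_limit": str(gpu_mem_limit),
}
# import onnxruntime as ort
# session = ort.InferenceSession("model.onnx",
#     providers=[("CUDAExecutionProvider", provider_options)])
```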

### arena_extend_strategy
The strategy for extending the device memory arena.

@@ -71,6 +87,8 @@ kSameAsRequested (1) | extend by the requested amount

Default value: kNextPowerOfTwo

_Note:_ Will be overridden by the contents of `default_memory_arena_cfg` (if specified)

### cudnn_conv_algo_search
The type of search done for cuDNN convolution algorithms.

@@ -82,10 +100,6 @@ DEFAULT (2) | default algorithm using CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PR

Default value: EXHAUSTIVE

### cudnn_conv_use_max_workspace
Check [tuning performance for convolution heavy models](#convolution-heavy-models) for details on what this flag does.
@@ -112,8 +126,29 @@ This flag is only supported from the V2 version of the provider options struct w

Default value: 0

### gpu_external_alloc
### gpu_external_free
### gpu_external_empty_cache
### tunable_op_enable
### tunable_op_tuning_enable
### tunable_op_max_tuning_duration_ms

### prefer_nhwc
This option is not available in default builds! ONNX Runtime has to be compiled with `onnxruntime_USE_CUDA_NHWC_OPS=ON`.
If enabled, the EP prefers NHWC operators over NCHW, and the required layout transforms are added to the model. Since NVIDIA Tensor Cores can only operate on the NHWC layout, this can increase performance if the model consists of many supported operators and does not need too many additional transpose nodes. Wider operator support is planned for the future.
This flag is only supported from the V2 version of the provider options struct when used via the C API. The V2 provider options struct can be created using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0d29cbf555aa806c050748cf8d2dc172) and updated using [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a4710fc51f75a4b9a75bde20acbfa0783).

Default value: 0
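A minimal sketch (Python API; assumes an NHWC-enabled build and a placeholder model path, hence the commented session line):

```
# Sketch: prefer NHWC operators; only effective in builds compiled with
# onnxruntime_USE_CUDA_NHWC_OPS=ON.
provider_options = {
    "device_id": "0",
    "prefer_nhwc": "1",
}
# import onnxruntime as ort
# session = ort.InferenceSession("model.onnx",
#     providers=[("CUDAExecutionProvider", provider_options)])
```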

## Performance Tuning
The [I/O Binding feature](../performance/tune-performance/iobinding.md) should be utilized to avoid overhead resulting from copies on inputs and outputs. Ideally, uploads and downloads of inputs can be hidden behind the inference by issuing asynchronous copies while inference is running. This is demonstrated in this [PR](https://github.com/microsoft/onnxruntime/pull/14088).
```
Ort::RunOptions run_options;
run_options.AddConfigEntry("disable_synchronize_execution_providers", "1");
session->Run(run_options, io_binding);
```
When synchronization of the execution providers is disabled, the user has to take care of synchronizing the compute stream after execution.
This feature should only be used with device-local memory or an ORT value allocated in [pinned memory](https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/); otherwise the issued download will be blocking and will not behave as desired.
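The same pattern can be sketched with the Python API (a hedged sketch: `model.onnx` and the binding calls are placeholders, and a CUDA-capable build is assumed, so the session lines are left commented):

```
# Run config entry that disables EP synchronization for this run.
run_config = {"disable_synchronize_execution_providers": "1"}

# import onnxruntime as ort
# session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
# io_binding = session.io_binding()
# ... bind inputs/outputs to pre-allocated device-local or pinned buffers ...
# run_options = ort.RunOptions()
# for key, value in run_config.items():
#     run_options.add_run_config_entry(key, value)
# session.run_with_iobinding(io_binding, run_options)
# io_binding.synchronize_outputs()  # user-managed synchronization after the run
```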

### Convolution-heavy models
