
Adding CUDNN Frontend and use for CUDA NN Convolution #19470

Merged
merged 8 commits on Aug 2, 2024

Conversation

Contributor

@JTischbein JTischbein commented Feb 8, 2024

Description

Added the cuDNN Frontend and used it for NHWC convolutions, optionally fusing the activation.

Backward compatible

  • Models that already contain FusedConv nodes can still run.
  • If ORT is built with cuDNN 8, the cuDNN frontend will not be built into the binary. Old kernels (using cudnn backend APIs) are used.

Major Changes

  • For cuDNN 9, we enable the cudnn frontend to fuse convolution and bias when the provider option fuse_conv_bias=1 is set.
  • Removed the FusedConv fusion from the graph transformer for the CUDA provider, so FusedConv nodes will no longer be added to the graph for the CUDA EP.
  • Updated cmake files regarding cudnn settings. The search order for the cuDNN installation during build is as follows:
    • the CUDNN_PATH environment variable
    • the onnxruntime_CUDNN_HOME cmake extra define. If a build starts from build.py/build.sh, the user can pass it through the --cudnn_home parameter, or via the CUDNN_HOME environment variable if --cudnn_home is not used.
    • the cudnn python package installation directory, like python3.xx/site-packages/nvidia/cudnn
    • the CUDA installation path
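The precedence above can be sketched as a small helper. This is illustrative only: `find_cudnn_home` and its candidate list are hypothetical and do not mirror ORT's actual cmake logic.

```python
import os
import site

def find_cudnn_home(cudnn_home_arg=None):
    """Illustrative sketch of the cuDNN search order described above.

    Checks, in order: the CUDNN_PATH environment variable, the
    --cudnn_home build argument (falling back to CUDNN_HOME), the cudnn
    pip package directory under site-packages, and finally the CUDA
    installation path. Returns the first candidate that exists on disk.
    """
    candidates = [
        os.environ.get("CUDNN_PATH"),                    # 1. CUDNN_PATH
        cudnn_home_arg or os.environ.get("CUDNN_HOME"),  # 2. --cudnn_home / CUDNN_HOME
    ]
    # 3. cudnn python package, e.g. .../site-packages/nvidia/cudnn
    for sp in site.getsitepackages():
        candidates.append(os.path.join(sp, "nvidia", "cudnn"))
    candidates.append(os.environ.get("CUDA_PATH"))       # 4. CUDA installation
    for path in candidates:
        if path and os.path.isdir(path):
            return path
    return None
```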

Potential Issues

  • If ORT is built with cuDNN 8, FusedConv fusion is no longer done automatically, so some models might see a performance regression. If users still want the FusedConv operator for performance reasons, there are multiple workarounds: use an older version of onnxruntime; or use an older version of ORT to save the optimized onnx model, then run it with the latest version of ORT. We believe the majority of users will run cudnn 9 when 1.20 is released, because cudnn 9 will have been available in ORT and PyTorch for 3 months at that time.
  • The cuDNN frontend uses TF32 by default, and the user cannot disable TF32 in convolution through the use_tf32 cuda provider option. If a user encounters an accuracy issue (e.g. for testing purposes), they have to set the environment variable NVIDIA_TF32_OVERRIDE=0 to disable TF32, or set a larger tolerance in testing. The documentation of use_tf32 needs an update about the TF32 impact on convolution.
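For example, TF32 can be disabled for a quick accuracy check via the environment variable; the perf-test invocation and model path below are placeholders:

```shell
# Disable TF32 math in cuDNN before running inference or tests.
export NVIDIA_TF32_OVERRIDE=0

# Then run the workload as usual, e.g. (placeholder invocation):
# ./onnxruntime_perf_test -e cuda -I -q -r 100 -d 1 model.onnx
```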

Follow ups

This is one of the PRs that target enabling NHWC convolution in the CUDA EP by default if the device supports it. Other changes will follow to make it possible.
(1) Enable prefer_nhwc by default for device with sm >= 70.
(2) Change fuse_conv_bias=1 by default after more testing.
(3) Add other NHWC operators (like Resize or UpSample).

Motivation and Context

The new cuDNN Frontend library provides the functionality to fuse operations and provides new heuristics for kernel selection. Here it fuses the convolution with the pointwise bias operation. On the NVIDIA ResNet50 we get a performance boost from 49.1144 ms to 42.4643 ms per inference on a 2560x1440 input (onnxruntime_perf_test -e cuda -I -q -r 100 -d 1 -i 'prefer_nhwc|1' resnet50.onnx).
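As a sketch, the options mentioned here can be passed as CUDA EP provider options from Python. The session creation is shown only as a comment since it needs a GPU build of onnxruntime; "resnet50.onnx" is a placeholder, and fuse_conv_bias is the option introduced in this PR (off by default):

```python
# CUDA EP provider options: prefer NHWC layout and enable conv+bias fusion.
# fuse_conv_bias is the new option from this PR; it is off by default.
cuda_provider_options = {
    "prefer_nhwc": "1",
    "fuse_conv_bias": "1",
}
providers = [("CUDAExecutionProvider", cuda_provider_options)]

# With a GPU build of onnxruntime installed, a session would be created as:
# import onnxruntime as ort
# sess = ort.InferenceSession("resnet50.onnx", providers=providers)
```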

@gedoensmax
Contributor

@hariharans29 Could you help review this? And ideally also help resolve the compile problems we are seeing with the cuDNN frontend.

@tianleiwu
Contributor

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Big Models


Azure Pipelines successfully started running 4 pipeline(s).

@JTischbein
Contributor Author

@microsoft-github-policy-service agree company="NVIDIA"

1 similar comment

@tianleiwu
Contributor

tianleiwu commented Feb 8, 2024

Any idea why there is a failure in creating the graph execution plan for the stable diffusion 1.5 VAE decoder:

2024-02-08 17:11:31.766176892 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running NhwcConv node. Name:'NhwcConv_35-/decoder/conv_out/Conv' Status Message: CUDNN_FE failure 6: GRAPH_EXECUTION_PLAN_CREATION_FAILED ; GPU=0 ; hostname=79ffae696371 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/nn/conv.cc ; line=400 ; expr=s_.cudnn_fe_graph->check_support(handle);

The node is like the following (the weight of shape 3x3x3x128 is in NHWC format): [image]

@tianleiwu
Contributor

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

@tianleiwu
Contributor

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Android CI Pipeline

@tianleiwu
Contributor

/azp run iOS CI Pipeline,ONNX Runtime React Native CI Pipeline


Pull request contains merge conflicts.

2 similar comments

@tianleiwu
Contributor

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

@tianleiwu
Contributor

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Android CI Pipeline

@tianleiwu
Contributor

/azp run iOS CI Pipeline,ONNX Runtime React Native CI Pipeline


Azure Pipelines successfully started running 2 pipeline(s).


Azure Pipelines successfully started running 10 pipeline(s).

1 similar comment

@tianleiwu
Contributor

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

@tianleiwu
Contributor

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Android CI Pipeline

@tianleiwu
Contributor

/azp run iOS CI Pipeline,ONNX Runtime React Native CI Pipeline


Azure Pipelines successfully started running 2 pipeline(s).


Azure Pipelines successfully started running 10 pipeline(s).

1 similar comment

@tianleiwu
Contributor

There are build errors on Windows:
D:\a_work\1\b\RelWithDebInfo_deps\cudnn_frontend-src\include\cudnn_frontend\graph_interface.h(444,19): Error C2248: 'cudnn_frontend::graph::Layernorm_attributes::forward_phase': cannot access private member declared in class 'cudnn_frontend::graph::Layernorm_attributes'

Some test errors for NhwcConv:

Run failed but expected success: Non-zero status code returned while running NhwcConv node. Name:'node1' Status Message: CUDNN failure 3: CUDNN_STATUS_BAD_PARAM ; GPU=0 ; hostname=8dc224a231be ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/nn/conv.cc ; line=548 ; expr=cudnnConvolutionForward(cudnn_handle, &alpha, s_.x_tensor, s_.x_data, s_.w_desc, s_.w_data, s_.conv_desc, s_.algo, workspace.get(), s_.workspace_bytes, &beta, s_.y_tensor, s_.y_data);
Google Test trace:
/onnxruntime_src/onnxruntime/test/providers/base_tester.cc:791: registered execution providers: CUDAExecutionProvider

@JTischbein JTischbein requested review from a team as code owners March 1, 2024 23:59

Azure Pipelines successfully started running 10 pipeline(s).

@tianleiwu
Contributor

/azp run Linux Android Emulator QNN CI Pipeline


Azure Pipelines successfully started running 1 pipeline(s).

@tianleiwu tianleiwu requested a review from jywu-msft August 2, 2024 17:02
Contributor

@pranavsharma pranavsharma left a comment


Approving for admin

@tianleiwu tianleiwu requested a review from snnn August 2, 2024 17:32
@hariharans29
Member

Could you please update the PR description with some high-level details? Some questions I had were:

  1. Is the cudnn frontend fork mentioned in the description still valid?

  2. What fusions will be enabled/disabled for NHWC/NCHW Conv for CUDA EP going forward?
    2.1) What backward compatibility guarantee will be offered for an already optimized frozen graph?

  3. What is the expected execution flow difference when using CUDA EP with CuDNN 8/ CuDNN 9+?

Contributor

@jchen351 jchen351 left a comment


Approved for ES

@tianleiwu
Contributor

Could you please update the PR description with some high-level details? Some questions I had were:

  1. Is the cudnn frontend fork mentioned in the description still valid?

It is not needed; the official 1.5.2 cudnn frontend is used.

  2. What fusions will be enabled/disabled for NHWC/NCHW Conv for CUDA EP going forward?
    2.1) What backward compatibility guarantee will be offered for an already optimized frozen graph?
  3. What is the expected execution flow difference when using CUDA EP with CuDNN 8 / CuDNN 9+?

Updated the description based on my understanding.

@tianleiwu tianleiwu merged commit 1391354 into microsoft:main Aug 2, 2024
77 checks passed
@hariharans29
Member

With cudnn 9, will the activation after the Conv be fused going forward for NHWC and NCHW?

@gedoensmax
Contributor

@hariharans29 As there were some regressions, it is not enabled by default. We hope to enable this as the default option in the future, as cuDNN's runtime-compiled kernels catch up.

tianleiwu added a commit that referenced this pull request Aug 9, 2024
### Description
* Fix migraphx build error caused by
#21598:
Added conditional compilation for a code block that depends on ROCm >= 6.2.
Note that the pipeline uses ROCm 6.0.

Unblock the orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed, and orttraining-amd-gpu-ci-pipeline pipelines:
* Disable a model test in linux GPU training ci pipelines caused by
#19470:
Sometimes, the cudnn frontend throws an exception that the cudnn graph does not
support a Conv node of the keras_lotus_resnet3D model on V100 GPU.
Note that the same test does not throw an exception in other GPU pipelines. The
failure might be related to the cudnn 8.9 and V100 GPU used in the pipeline
(Ampere GPUs and cuDNN 9.x do not have the issue).
The actual fix requires fallback logic, which will take time to
implement, so we temporarily disable the test in training pipelines.
* Force install torch for cuda 11.8. (The docker image has torch 2.4.0 for
cuda 12.1 to build the torch extension, which is not compatible with cuda
11.8). Note that this is a temporary workaround. A more elegant fix is to
ensure the right torch version in the docker build step, which might need
updates to install_python_deps.sh and the corresponding requirements.txt.
* Skip test_gradient_correctness_conv1d since it causes a segmentation fault.
The root cause needs more investigation (maybe due to the cudnn frontend as
well).
* Skip test_aten_attention since it causes an assert failure. The root cause
needs more investigation (maybe due to the torch version).
* Skip orttraining_ortmodule_distributed_tests.py since it fails with an error
that the compiler for the torch extension does not support C++17. One possible
fix is to set the following compile argument inside setup.py of the
fused_adam extension: extra_compile_args['cxx'] = ['-std=c++17'].
However, due to the urgency of unblocking the pipelines, just disable
the test for now.
* Skip test_softmax_bf16_large. For some reason,
torch.cuda.is_bf16_supported() returns True on V100 with torch 2.3.1, so
the test was run in CI, but V100 does not support bf16 natively.
* Fix a typo of 'deterministic'.

prathikr pushed a commit that referenced this pull request Aug 9, 2024
sumitsays pushed a commit that referenced this pull request Aug 9, 2024
tianleiwu added a commit that referenced this pull request Aug 15, 2024
### Description
Exclude cuDNN 9 and CUDA 12 DLLs from manylinux wheel to reduce python
package size.

### Motivation and Context

The 1.20.0 ort-nightly-gpu python wheels on linux are suddenly > 800 MB
in size. The wheels built on 1.19 release branch have a size of around
220 MB.

The size change is caused by
#19470.
tianleiwu pushed a commit that referenced this pull request Sep 10, 2024
### Description
Added the cuDNN Frontend and used it for the NHWC ConvTranspose op, including
an option for bias fusion. Similar to this [Conv PR](#19470).

### Backward compatible
If ORT is built with cuDNN 8, cuDNN frontend will not be built into
binary. Old kernels (using cudnn backend APIs) are used.

### Major Changes
For cuDNN 9, we enable the cudnn frontend to fuse the data-gradient
convolution and bias when the provider option fuse_conv_bias=1 is set.

### Potential Issues
The cuDNN frontend uses TF32 by default. It can be disabled using the use_tf32
cuda provider option, but if the cuDNN frontend encounters issues building an
operation graph, it will fall back to using TF32.

### Follow ups
This is one of the PRs that target enabling NHWC, here for the ConvTranspose
operation, in the CUDA EP by default if the device supports it.
Other changes will follow to make it possible.
(1) Enable prefer_nhwc by default for device with sm >= 70.
(2) Change fuse_conv_bias=1 by default after more testing.
(3) Add other NHWC operators (like Resize or UpSample).

### Motivation and Context
The new CUDNN Frontend library provides the functionality to fuse
operations and provides new heuristics for kernel selection. Here it
fuses the convolution data gradient operation (ConvTranspose) with the
pointwise bias operation.

### Minor Change
There was a small bug in the CUDA convolution operation when
`GetCudnnConv1dPadToNc1d` was enabled.
chilo-ms pushed a commit that referenced this pull request Nov 15, 2024
### Description
Fixes build failure for the cuda minimal build




### Motivation and Context
[This change](#19470) in 1.20 is causing build failures for the cuda minimal
build. Essentially, some cudnn logic was not guarded by `USE_CUDA_MINIMAL`.
The build also looks for cudnn, while the cuda minimal build shouldn't depend
on it, resulting in a linking error.


cc @gedoensmax @chilo-ms
ishwar-raut1 pushed a commit to ishwar-raut1/onnxruntime that referenced this pull request Nov 19, 2024
guschmue pushed a commit that referenced this pull request Dec 2, 2024
ankitm3k pushed a commit to intel/onnxruntime that referenced this pull request Dec 11, 2024