
ConvTranspose using CUDNN Frontend with NHWC support #21752

Merged: 5 commits into microsoft:main, Sep 10, 2024

Conversation

JTischbein
Contributor

Description

Added the cuDNN Frontend and used it for the NHWC ConvTranspose op, including an option for bias fusion. Similar to this Conv PR.

Backward compatible

If ORT is built with cuDNN 8, the cuDNN frontend will not be built into the binary; the old kernels (using the cuDNN backend APIs) are used instead.

Major Changes

For cuDNN 9, the cuDNN frontend is enabled to fuse the data-gradient convolution with the bias when the provider option fuse_conv_bias=1 is set.

Potential Issues

The cuDNN frontend uses TF32 by default. TF32 can be disabled via the use_tf32 CUDA provider option, but if the cuDNN frontend encounters issues building an operation graph, it will fall back to using TF32.
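
For illustration, a minimal sketch of configuring these options through the ORT C++ API. The option keys fuse_conv_bias, use_tf32, and prefer_nhwc are the ones named in this PR; the model path is a placeholder:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;

  // Create CUDA provider options and set string key/value pairs.
  const OrtApi& api = Ort::GetApi();
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));

  const char* keys[] = {"fuse_conv_bias", "use_tf32", "prefer_nhwc"};
  const char* values[] = {"1", "0", "1"};  // fuse bias; disable TF32; prefer NHWC
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 3));

  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
  api.ReleaseCUDAProviderOptions(cuda_options);

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}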

Follow ups

This is one of the PRs aiming to enable NHWC (here, for the ConvTranspose operation) by default in the CUDA EP when the device supports it. Other changes will follow to make that possible:
(1) Enable prefer_nhwc by default for devices with sm >= 70.
(2) Make fuse_conv_bias=1 the default after more testing.
(3) Add other NHWC operators (such as Resize or UpSample).

Motivation and Context

The new cuDNN Frontend library provides the ability to fuse operations, as well as new heuristics for kernel selection. Here it fuses the convolution data-gradient operation (ConvTranspose) with the pointwise bias operation.
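
A rough sketch of what such a fusion graph looks like, assuming the cudnn-frontend v1.x graph API. This is illustrative only, not the exact code in this PR, and the tensor shapes are made up:

#include <cudnn_frontend.h>

// Express ConvTranspose as a fused data-gradient convolution + bias add.
cudnn_frontend::graph::Graph BuildDgradBiasGraph() {
  namespace fe = cudnn_frontend;
  fe::graph::Graph graph;
  graph.set_io_data_type(fe::DataType_t::HALF)
      .set_compute_data_type(fe::DataType_t::FLOAT);

  // Dims are [N, C, H, W]; the NHWC layout is expressed via the strides.
  auto dy = graph.tensor(fe::graph::Tensor_attributes()
                             .set_name("dy")
                             .set_dim({8, 64, 56, 56})
                             .set_stride({64 * 56 * 56, 1, 64 * 56, 64}));
  auto w = graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("w")
                            .set_dim({64, 32, 3, 3})
                            .set_stride({32 * 3 * 3, 1, 32 * 3, 32}));

  // Data-gradient convolution: the backward-data pass that implements ConvTranspose.
  auto dx = graph.conv_dgrad(dy, w,
                             fe::graph::Conv_dgrad_attributes()
                                 .set_padding({1, 1})
                                 .set_stride({1, 1})
                                 .set_dilation({1, 1}));
  dx->set_dim({8, 32, 56, 56});  // dgrad output dims cannot be inferred

  // Fused pointwise bias add on the dgrad output.
  auto b = graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("b")
                            .set_dim({1, 32, 1, 1})
                            .set_stride({32, 1, 1, 1}));
  auto y = graph.pointwise(dx, b,
                           fe::graph::Pointwise_attributes()
                               .set_mode(fe::PointwiseMode_t::ADD));
  y->set_output(true);
  return graph;
}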

Minor Change

There was a small bug in the CUDA convolution operation when GetCudnnConv1dPadToNc1d was enabled.
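
For context, a small self-contained sketch (the helper name is hypothetical) of the 1-D-to-2-D shape padding that the review snippets below perform:

#include <cstdint>
#include <vector>

// A 3-D (1-D conv) shape is padded to 4-D by inserting a dummy dimension of 1.
// With pad_to_nc1d the dummy acts as H ([N, C, 1, L] / [N, 1, L, C]);
// otherwise it acts as W ([N, C, L, 1] / [N, L, 1, C]).
std::vector<int64_t> PadConv1dShape(std::vector<int64_t> dims, bool nhwc, bool pad_to_nc1d) {
  const auto insert_at = pad_to_nc1d ? (nhwc ? 1 : 2)   // fake H dimension
                                     : (nhwc ? 2 : 3);  // fake W dimension
  dims.insert(dims.begin() + insert_at, 1);
  return dims;
}

For example, PadConv1dShape({8, 64, 128}, /*nhwc=*/false, /*pad_to_nc1d=*/true) yields {8, 64, 1, 128}.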

@tianleiwu
Contributor

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

@tianleiwu
Contributor

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Android CI Pipeline


Azure Pipelines successfully started running 9 pipeline(s).

// see Conv<T, NHWC>::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details.
if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
  // add fake H dimension
  const auto insert_at = NHWC ? 1 : 2;

Check warning — Code scanning / PREfast: The const variable 'insert_at' can be computed at compile-time. Consider using constexpr (con.5).
  w_dims.insert(w_dims.begin() + insert_at, 1);
} else {
  // add fake W dimension
  const auto insert_at = NHWC ? 2 : 3;
  w_dims.insert(w_dims.begin() + insert_at, 1);
}

Check warning — Code scanning / PREfast: The const variable 'insert_at' can be computed at compile-time. Consider using constexpr (con.5).

ConvTransposeAttributes::Prepare p;
// PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels'
const bool transposed_input_channels = false;

Check warning — Code scanning / PREfast: The const variable 'transposed_input_channels' can be computed at compile-time. Consider using constexpr (con.5).

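A sketch of the change the con.5 warnings suggest: since NHWC is a compile-time template parameter and the flag is a literal, both variables can be declared constexpr instead of const.

// Suggested by PREfast (con.5): values known at compile time can be constexpr.
constexpr auto insert_at = NHWC ? 1 : 2;  // likewise 2 : 3 in the else branch
constexpr bool transposed_input_channels = false;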
@tianleiwu
Contributor

/azp run Linux GPU CI Pipeline, Windows GPU TensorRT CI Pipeline


Azure Pipelines successfully started running 2 pipeline(s).

@tianleiwu
Contributor

@JTischbein, there were some test errors in pipelines. Did you have a chance to take a look?

@JTischbein
Contributor Author

> @JTischbein, there were some test errors in pipelines. Did you have a chance to take a look?

The test errors were not reproducible for me. We are currently testing other GPUs; I will keep you updated. For the build error, I will add another [[maybe_unused]] and push it now.
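
(For context, a one-line hypothetical sketch of that kind of fix, assuming the build error is an unused-variable warning treated as an error in some configurations:)

// Hypothetical: marks a variable that is unused in some build configurations.
[[maybe_unused]] const bool transposed_input_channels = false;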

@tianleiwu
Contributor

/azp run Linux GPU CI Pipeline, Windows GPU TensorRT CI Pipeline


Azure Pipelines successfully started running 2 pipeline(s).

@tianleiwu
Contributor

/azp run Linux GPU CI Pipeline, Windows GPU TensorRT CI Pipeline


Azure Pipelines successfully started running 2 pipeline(s).

@tianleiwu
Contributor

/azp run Linux GPU CI Pipeline, Windows GPU TensorRT CI Pipeline


Azure Pipelines successfully started running 2 pipeline(s).

@tianleiwu
Contributor

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU TensorRT CI Pipeline

@tianleiwu
Contributor

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline

@tianleiwu
Contributor

/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline


Azure Pipelines successfully started running 3 pipeline(s).


Azure Pipelines successfully started running 9 pipeline(s).


Azure Pipelines successfully started running 10 pipeline(s).

@tianleiwu tianleiwu merged commit 20d9464 into microsoft:main Sep 10, 2024
79 of 81 checks passed