Make CUDA a NHWC EP #17200

Merged
hariharans29 merged 25 commits into microsoft:main on Oct 16, 2023
Conversation

gedoensmax
Contributor

gedoensmax commented Aug 17, 2023

Description

CUDA inference speed relies heavily on Tensor Cores. For Tensor Cores to achieve optimal throughput, the data layout must be NHWC rather than NCHW.

Motivation and Context

This is especially important for convolutional networks. I will illustrate it with a very simple network:

import torch
import torch.nn as nn

class Net1(nn.Module):

    def __init__(self):
        super(Net1, self).__init__()
        # stack of Conv2d layers: 8 -> 32 -> 64 -> 128 -> 128 -> 128 channels
        self.m = nn.ModuleList([
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=5, stride=1),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, bias=False),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, bias=False),
        ])
    def forward(self, x):
        for module in self.m:
            x = module(x)
        return x


if __name__ == "__main__":
    dtype = torch.half
    device = "cuda"

    dummy_input = torch.randn(8, 8, 512, 512, dtype=dtype, device=device)
    model = Net1().to(dtype=dtype, device=device)
    input_names = ["input1"]
    output_names = ["output1"]
    torch.onnx.export(model, dummy_input, "test.onnx",
                      input_names=input_names, output_names=output_names)

I profiled the launch of ./build/RelWithDebInfo/onnxruntime_perf_test -e cuda -I -q -t 5 test.onnx using nsys and NVTX ranges.
Current master launches the kernels below:

![image](https://github.com/microsoft/onnxruntime/assets/44298237/81655fce-0f8e-4f78-9335-b858a8c8977b)

If I add the newly introduced -l flag, we see the kernels below:

![image](https://github.com/microsoft/onnxruntime/assets/44298237/fceb5d6f-c12d-442b-b15a-948797630008)

Notice that the per-operation NCHW<->NHWC conversion kernels are gone. The layout optimizer instead inserts a transpose op only as the first and last op of the whole network. The op_generic_tensor_kernel corresponds to the bias addition, which should also be optimized out next.

Measured across some very basic models:

| CUDA EP | NCHW [ms] (-e cuda -t 5 -q) | NHWC [ms] (-e cuda -t 5 -q -l) | Speedup |
|:------------------------|------:|------:|-----:|
| resnet101-v2-7_bs8_fp16 | 18.33 | 13.07 | 1.40 |
| resnet101-v2-7_bs8 | 21.8 | 12.06 | 1.81 |
| test | 102.07 | 73.62 | 1.39 |

Average speedup: 1.53

Outlook

Next, the plan is to first write a templated unit test that checks NHWC ops for correctness against their NCHW counterparts. After that we have to transition more ops so that performance improvements can be measured on a broader range of models. Currently this is not easily possible, as we do not yet support all ops in the NHWC domain.

@gedoensmax
Contributor Author

@skottmckay in case you have some feedback or ideas on unit testing this transition, or even on allowing a "partial NHWC" provider. I think the latter is not possible, though, as the layout transformer asks for the preferred layout of the whole provider rather than of each op, right?

@jywu-msft
Member

@hariharans29
Member

@skottmckay in case you have some feedback or ideas on unit testing this transition, or even on allowing a "partial NHWC" provider. I think the latter is not possible, though, as the layout transformer asks for the preferred layout of the whole provider rather than of each op, right?

Are you wondering how to unit test an NHWC CUDA op?
Just thinking out loud: let us say that, for testing purposes, you create a single-node model (e.g. Conv). We should be able to instantiate a session for the same model in both NCHW (the default for the CUDA EP) and NHWC (with your new provider option set). I am guessing that for NHWC we should see a Transpose inserted before and after the node by the layout transformer. If we provide the same input to the NCHW session and the NHWC session, we should get back the same results in both cases. Would this work to unit test a single op?
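
For illustration, a minimal Python sketch of such a dual-session comparison (assuming the prefer_nhwc CUDA EP option that this PR ultimately introduced, and the test.onnx model from the description):

```
import numpy as np
import onnxruntime as ort

# Run the same model with the default NCHW CUDA EP and with the NHWC-preferring
# CUDA EP, then check that both produce the same outputs.
x = np.random.randn(8, 8, 512, 512).astype(np.float16)

sess_nchw = ort.InferenceSession("test.onnx", providers=["CUDAExecutionProvider"])
sess_nhwc = ort.InferenceSession(
    "test.onnx",
    providers=[("CUDAExecutionProvider", {"prefer_nhwc": "1"})],
)

y_nchw = sess_nchw.run(None, {"input1": x})[0]
y_nhwc = sess_nhwc.run(None, {"input1": x})[0]

# fp16 convolutions may differ slightly between layouts, so compare with a tolerance
np.testing.assert_allclose(y_nchw, y_nhwc, rtol=1e-2, atol=1e-2)
```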

@hariharans29
Member

hariharans29 commented Aug 18, 2023

Adding a comment to track documenting this feature here once this lands. Sample PR: #10859

@hariharans29
Member

hariharans29 commented Aug 18, 2023

Measured across some very basic models:

| CUDA EP | NCHW [ms] (-e cuda -t 5 -q) | NHWC [ms] (-e cuda -t 5 -q -l) | Speedup |
|:------------------------|------:|------:|-----:|
| resnet101-v2-7_bs8_fp16 | 18.33 | 13.07 | 1.40 |
| resnet101-v2-7_bs8 | 21.8 | 12.06 | 1.81 |
| test | 102.07 | 73.62 | 1.39 |

Average speedup: 1.53

Just curious - if NHWC is needed to leverage Tensor Cores better, how are we seeing a speedup for resnet101-v2-7_bs8 (which I assume is fp32) when using NHWC compared to NCHW? Tensor Cores are primarily used for fp16 MMA, right?

@twoapples1

@skottmckay in case you have some feedback or ideas on unit testing this transition, or even on allowing a "partial NHWC" provider. I think the latter is not possible, though, as the layout transformer asks for the preferred layout of the whole provider rather than of each op, right?

Hello, I have some questions about this test. It is undeniable that convolution is faster with the NHWC data format than with NCHW. However, in deep learning models convolution is not the only operator; there are others such as upsample and maxpool. Considering that ONNX currently only supports the NCHW data format, I think there are two possible methods:
The first method is to convert the data between NCHW and NHWC immediately before and after each convolution, so that only the convolution itself is computed in NHWC, as in the first test. However, such frequent data conversion inevitably incurs time overhead.
The second method is to convert the computation of the entire graph to the NHWC data format. In this case only the first and last operators would require a data format conversion. However, we would need to add an NHWC implementation for each operator, and when parsing the ONNX model we would need to pay attention to the layout of the weight parameters. That is a lot of work. Are both of these methods correct? Based on the test, do you intend to choose the second method?

@gedoensmax
Contributor Author

gedoensmax commented Aug 21, 2023

session for the same model in both NCHW (default for CUDA EP) and NHWC (with your new provider option set)

@hariharans29 exactly, that is my plan, and I think with the example you sent me on Teams:

TEST(AttentionTest, AttentionPastState_dynamic) {
  // ORT enables TF32 in GEMM for A100. TF32 will cause precision loss and fail this test.
  // Do not run this test unless TF32 is disabled explicitly.
  if (HasCudaEnvironment(800) && ParseEnvironmentVariableWithDefault<int>("NVIDIA_TF32_OVERRIDE", 1) != 0) {
    GTEST_SKIP() << "Skipping AttentionPastState_dynamic in A100 since TF32 is enabled";
    return;
  }
  // create rand inputs
  RandomValueGenerator random{};
  std::vector<int64_t> input_dims{2, 5, 768};
  std::vector<float> input_data = random.Gaussian<float>(input_dims, 0.0f, 0.3f);
  std::vector<int64_t> weight_dims{768, 2304};
  std::vector<float> weight_data = random.Gaussian<float>(weight_dims, 0.0f, 0.3f);
  std::vector<int64_t> bias_dims{2304};
  std::vector<float> bias_data = random.Gaussian<float>(bias_dims, 0.0f, 0.3f);
  std::vector<int64_t> past_dims{2, 2, 12, 15, 64};
  std::vector<float> past_data = random.Gaussian<float>(past_dims, 0.0f, 0.3f);
  OpTester test("Attention", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("num_heads", 12);
  test.AddAttribute<int64_t>("unidirectional", 1);
  test.AddInput<float>("input", input_dims, input_data);
  test.AddInput<float>("weight", weight_dims, weight_data);
  test.AddInput<float>("bias", bias_dims, bias_data);
  test.AddOptionalInputEdge<int32_t>();
  test.AddInput<float>("past", past_dims, past_data);
  test.AddReferenceOutputs("testdata/attention_past_state.onnx", 0.005f);
  test.Run();
}
that should easily be possible.

Also, thanks for the other comments, e.g. on the workspace and the conv algo - I think I have to make passing this via -I more general anyway.

@gedoensmax
Contributor Author

@twoapples1 Ideally we want to transition all ops to NHWC so that there are as few conversions as possible inside the model. As you can see from the nsys traces I shared as a screenshot, the NCHW-to-NHWC conversions and back can currently happen at every single node!
So yes, my intention is to convert all ops to NHWC and to transpose the weights once, when the network is loaded.
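
As a purely illustrative sketch of that one-time weight transpose (the shapes below are hypothetical, matching the first conv of the example network, and this is not ORT's actual implementation):

```
import numpy as np

# A conv weight stored for NCHW is laid out as (out_channels, in_channels, kH, kW);
# NHWC kernels instead expect (out_channels, kH, kW, in_channels).
w_nchw = np.random.randn(32, 8, 5, 5).astype(np.float16)
w_nhwc = np.ascontiguousarray(np.transpose(w_nchw, (0, 2, 3, 1)))
print(w_nchw.shape, "->", w_nhwc.shape)  # (32, 8, 5, 5) -> (32, 5, 5, 8)
```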

@gedoensmax
Contributor Author

Just curious - if NHWC is needed to leverage Tensor Cores better, how are we seeing a speedup for resnet101-v2-7_bs8 (which I assume is fp32) when using NHWC compared to NCHW? Tensor Cores are primarily used for fp16 MMA, right?

You are absolutely right, but since I tested on an Ada-series GPU I also benefit from TF32 acceleration. On a Turing-series GPU, NHWC is actually a little slower than NCHW - not because of the kernels, but because FusedConv is used for NCHW and is not selected for NHWC.

@hariharans29
Member

Just curious - if NHWC is needed to leverage Tensor Cores better, how are we seeing a speedup for resnet101-v2-7_bs8 (which I assume is fp32) when using NHWC compared to NCHW? Tensor Cores are primarily used for fp16 MMA, right?

You are absolutely right, but since I tested on an Ada-series GPU I also benefit from TF32 acceleration. On a Turing-series GPU, NHWC is actually a little slower than NCHW - not because of the kernels, but because FusedConv is used for NCHW and is not selected for NHWC.

It seems like cuDNN generally has better Conv implementations for NHWC (irrespective of the data type)?

@skottmckay
Contributor

@skottmckay in case you have some feedback or ideas on unit testing this transition, or even on allowing a "partial NHWC" provider. I think the latter is not possible, though, as the layout transformer asks for the preferred layout of the whole provider rather than of each op, right?

One option would be to register another EP that shares a lot of the implementation with the existing EP but asks for the NHWC layout. If it has higher priority than the existing EP, it would get the first chance to request nodes and have them converted to NHWC. Remaining nodes could be taken by the existing EP. Having 2 CUDA EPs would probably be slightly confusing though, so it depends on whether that would be a short-term or long-term situation. It may also cause other complexities (e.g. would it try to synchronize between nodes assigned to different EPs) that would take time to work through.

Alternatively, the EP interface could be expanded to try and support this. It is not clear how that could/should look though. Short term, it may be better to manually add the handling in layout_transformation.cc until we see whether any other EPs would ever need this. We could maybe expand GetEPLayoutSensitiveOps in this draft PR; I may have time to get back to that PR in the next couple of weeks. If you removed operators from the layout-sensitive set for the EP, the layout transformation would not convert them to NHWC. So the EP would say NHWC is its preferred layout, but only a subset of layout-sensitive nodes would be converted.

@skottmckay
Contributor

Alternatively, the EP interface could be expanded to try and support this. It is not clear how that could/should look though. Short term, it may be better to manually add the handling in layout_transformation.cc until we see whether any other EPs would ever need this. We could maybe expand GetEPLayoutSensitiveOps in this draft PR.

It might be better to abstract this out a little in TransformLayoutForEP and have a function like bool ConvertLayoutForNode(const Node& node) instead of directly looking up OpType() in the set of layout-sensitive ops as we do currently:

const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();
// to convert to NHWC we need to wrap layout sensitive nodes to Transpose from NCHW to NHWC and back.
for (auto& node : api_graph->Nodes()) {
  if (layout_sensitive_ops.count(node->OpType())) {
    if (node->GetExecutionProviderType() != execution_provider.Type()) {
      continue;
    }
    // ...
The default implementation of ConvertLayoutForNode can do the existing lookup in the set of layout-sensitive ops, but that also gives a place where EP-specific behavior can be plugged in. In the case of the CUDA EP it can return true for the subset of layout-sensitive nodes that it wants converted. Passing in the Node allows you to check the op type, the domain, and the EP it is assigned to.

I'd start with this EP-specific logic living in layout_transformation.cc, but if necessary (i.e. more EPs need to control this behavior) we could update the EP API to allow the EP to optionally provide a ConvertLayoutForNode delegate.

@gedoensmax
Contributor Author

It might be better to abstract this out a little in TransformLayoutForEP and have a function like bool ConvertLayoutForNode(const Node& node) instead of directly looking up OpType() in the set of layout-sensitive ops as we do currently:

A big +1 from my side, as it would also allow an easier transition from one layout to the other. I am open to discussing and helping out with that, but for now I'll probably leave the design of such an API to the core CUDA EP devs here.

gedoensmax changed the title from "Draft: Make CUDA a NHWC EP" to "Make CUDA a NHWC EP" on Aug 30, 2023
@hariharans29
Member

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, Linux QNN CI Pipeline

@hariharans29
Member

/azp run Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows ARM64 QNN CI Pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed, ONNX Runtime React Native CI Pipeline, Windows x64 QNN CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

1 similar comment
@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@hariharans29
Member

@gedoensmax - Can you please resolve the conflict ?

@gedoensmax
Contributor Author

Done.

@hariharans29
Member

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, Linux QNN CI Pipeline

@hariharans29
Member

/azp run Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows ARM64 QNN CI Pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed, ONNX Runtime React Native CI Pipeline, Windows x64 QNN CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

hariharans29 merged commit 7c17e33 into microsoft:main Oct 16, 2023
65 checks passed
hariharans29 added a commit that referenced this pull request Oct 17, 2023
### Description
This PR:

(1) Fixes the AMD builds after #17200 broke them (we need to remember to run AMD builds when merging external CUDA PRs next time).

(2) Turns on the NHWC CUDA feature in the Linux GPU CI. The extra time spent building a few more files and running a few more tests will not be much.

Test Linux GPU CI run :
https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1170770

### Motivation and Context
Keep the NHWC CUDA ops tested (#17200) and guard against regressions.
jchen351 pushed a commit that referenced this pull request Oct 18, 2023
jchen351 pushed a commit that referenced this pull request Oct 18, 2023
@YangQiangli

Hello, I've noticed that NHWC has been merged into the main branch. I tried to enable it by adding "--cmake_extra_defines onnxruntime_USE_CUDA_NHWC_OPS=ON" and compiling. However, when I attempt to test the performance using "onnxruntime_perf_test" with "-e cuda -t 5 -q -l", it gives an error saying there is no "-l" configuration. How can I use onnxruntime_perf_test to test the performance of NHWC?

@gedoensmax
Contributor Author

Oh sorry, I did not update the description. The option for that is now -i "prefer_nhwc|1". @YangQiangli
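
Putting the pieces from this thread together, a minimal sketch of enabling the NHWC path after this change (assuming a build with onnxruntime_USE_CUDA_NHWC_OPS=ON) could look like:

```
# onnxruntime_perf_test equivalent: -e cuda -t 5 -q -i "prefer_nhwc|1" test.onnx
import onnxruntime as ort

# Pass the CUDA EP's prefer_nhwc option through the providers argument.
sess = ort.InferenceSession(
    "test.onnx",
    providers=[("CUDAExecutionProvider", {"prefer_nhwc": "1"})],
)
```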

kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
### Description
- Treat Resize as layout sensitive by default
  - whilst the ONNX spec does not specify a layout, EPs tend to implement only one
- Add a second usage in L2 of the TransposeOptimizer to plug in the ability to push a Transpose through a Resize assigned to the CPU EP
- Allow EP-specific logic that changes the ops considered layout sensitive to be plugged in
  - expected usage is for microsoft#17200

### Motivation and Context
Finish simplifying/clarifying the transpose optimization and layout transformation proposed in microsoft#15552. This PR along with microsoft#17618 should complete the changes.

---------

Co-authored-by: Edward Chen <[email protected]>
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024