Make CUDA a NHWC EP #17200
Conversation
@skottmckay in case you have some feedback or ideas on unit testing this transition, or even on allowing a "partial NHWC" provider. But I think that's not possible, as the layout transformer asks for the preferred layout of the whole provider rather than per op, right?
+@pranavsharma, +@hariharans29, +@tianleiwu, +@yufenglee FYI
Are you wondering how to unit test an NHWC CUDA op?
Just curious - if NHWC is needed to leverage tensor cores better, how are we seeing the speedup for
Hello, I have some questions about this test. It is undeniable that convolution in the NHWC data format is faster than in NCHW. However, in deep learning models convolution is not the only operator; there are other operators such as upsample and maxpool. Considering that ONNX currently only supports the NCHW data format, I think there are two methods for reference:
@hariharans29 Exactly, that is my plan, and I think the link you sent me on Teams will help: onnxruntime/onnxruntime/test/contrib_ops/attention_op_test.cc, lines 2015 to 2049 in d65aa54
Also, thanks for the other comments on the workspace and conv algo - I think I have to pass this via
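To make the unit-testing idea above a bit more concrete, here is a rough model-level sketch in Python (not the C++ OpTester template from attention_op_test.cc): run the test.onnx model from the PR description on the default (NCHW) CUDA EP and on the NHWC path, and compare outputs. The `prefer_nhwc` provider option used to enable the NHWC path is the name mentioned later in this thread; treating it as a Python-level CUDA provider option is an assumption.

```
# Rough sketch (assumptions noted above): compare the default NCHW CUDA EP output
# against the NHWC path on the test.onnx model from the PR description.
import numpy as np
import onnxruntime as ort

x = np.random.randn(8, 8, 512, 512).astype(np.float16)
feed = {"input1": x}

nchw_sess = ort.InferenceSession("test.onnx", providers=["CUDAExecutionProvider"])
nhwc_sess = ort.InferenceSession(
    "test.onnx",
    providers=[("CUDAExecutionProvider", {"prefer_nhwc": "1"})],  # assumed option name
)

(ref,) = nchw_sess.run(None, feed)
(out,) = nhwc_sess.run(None, feed)

# Graph inputs/outputs stay NCHW either way; only the internal layout differs,
# so the results should agree within fp16 tolerance.
np.testing.assert_allclose(out, ref, rtol=2e-2, atol=2e-2)
```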
@twoapples1 Ideally we want to transition all ops to NHWC so that there are as few conversions as possible inside the model. As you can see from the nsys traces that I shared as a screenshot, currently the NCHW-to-NHWC conversions and back can happen at every single node!
You are absolutely right, but as I tested on an Ada series GPU I also get to enjoy TF32 acceleration. On a Turing series GPU, NHWC is actually a little slower than NCHW, not because of the kernels, but because FusedConv is used for NCHW and is not selected for NHWC.
It seems like cuDNN has better implementations in general for Conv with NHWC (irrespective of the data type)?
One option would be to register another EP that shares a lot of the implementation with the existing EP but asks for NHWC layout. If it is higher priority than the existing EP it would get first chance to request nodes and have them converted to NHWC. Remaining nodes could be taken by the existing EP. Probably slightly confusing to have 2 CUDA EPs though, so it depends on whether that will be a short-term or long-term situation. Having 2 EPs may also cause other complexities (e.g. would it try to synchronize between nodes assigned to different EPs) that would take time to work through.

Alternatively the EP interface could be expanded to try and support this. It's not clear how that could/should look though.

Short term it may be better to manually add the handling in layout_transformation.cc until we see if any other EPs would ever need this. We could maybe expand the GetEPLayoutSensitiveOps in this draft PR. I may have time to get back to that PR in the next couple of weeks. If you removed operators from the layout sensitive set for the EP, the layout transformation would not convert them to NHWC. So the EP would say NHWC is its preferred layout, but only a subset of layout sensitive nodes would be converted.
It might be better to abstract this out a little in TransformLayoutForEP and instead have a function like the one at onnxruntime/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc, lines 69 to 75 in 38ea8c3.
The default implementation of ConvertLayoutForNode can do the existing lookup in the set of layout sensitive ops, but that also gives a place where EP-specific things can be plugged in. In the case of the CUDA EP it can return true for the subset of layout sensitive nodes that it wants converted. Passing in the Node allows you to check the op type, the domain, and the EP it is assigned to. I'd start with this EP-specific logic living in layout_transformation.cc, but if necessary (i.e. more EPs need to control this behavior) we could update the EP API to allow the EP to optionally provide a ConvertLayoutForNode delegate.
A big +1 from me, as it would also allow for an easier transition from one layout to the other. I am open to discussing and helping out with that, but for now I would say I'll probably leave the design of such an API to the core CUDA EP devs here.
/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, Linux QNN CI Pipeline
/azp run Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows ARM64 QNN CI Pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed, ONNX Runtime React Native CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 9 pipeline(s).
Azure Pipelines successfully started running 9 pipeline(s).
@gedoensmax - Can you please resolve the conflict?
Done.
/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, Linux QNN CI Pipeline
/azp run Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows ARM64 QNN CI Pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed, ONNX Runtime React Native CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 9 pipeline(s).
Azure Pipelines successfully started running 9 pipeline(s).
### Description
This PR:
(1) Fixes AMD builds after #17200 broke them (need to remember to run AMD builds when merging external CUDA PRs next time).
(2) Turns on the NHWC CUDA feature in the Linux GPU CI. The extra time spent building a few more files and running a few more tests will not be much.
Test Linux GPU CI run: https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1170770
### Motivation and Context
Keep the NHWC CUDA ops tested (#17200) and guard against regressions.
Oh sorry, I did not update the description. The option for that will be `-i "prefer_nhwc|1"`. @YangQiangli
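For reference, a rough sketch of the session-level equivalent of that perf_test flag; the mapping of `prefer_nhwc` onto a CUDA execution provider option in the Python API is an assumption based on the option name above.

```
# onnxruntime_perf_test: pass the option through -i, e.g.
#   ./build/RelWithDebInfo/onnxruntime_perf_test -e cuda -t 5 -q -i "prefer_nhwc|1" test.onnx
# Python API: the same option passed as a CUDA execution provider option (assumed mapping).
import onnxruntime as ort

sess = ort.InferenceSession(
    "test.onnx",
    providers=[("CUDAExecutionProvider", {"prefer_nhwc": "1"})],
)
```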
### Description
- Treat Resize as layout sensitive by default: whilst the ONNX spec does not specify a layout, EPs tend to implement only one.
- Add a second usage in L2 of the TransposeOptimizer to plug in the ability to push a Transpose through a Resize assigned to the CPU EP.
- Allow EP-specific logic for changing the set of ops considered layout sensitive to be plugged in; expected usage is for microsoft#17200.
### Motivation and Context
Finish simplifying/clarifying the transpose optimization and layout transformation that was proposed in microsoft#15552. This PR along with microsoft#17618 should complete the changes.
---------
Co-authored-by: Edward Chen <[email protected]>
### Description

CUDA inference speed relies heavily on Tensor Cores. For Tensor Cores to achieve optimal throughput, the data layout needs to be NHWC rather than NCHW.

### Motivation and Context

This is especially important for convolutional networks. I will illustrate it using a very simple network:

```
import torch
import torch.nn as nn


class Net1(nn.Module):
    def __init__(self):
        super(Net1, self).__init__()
        # simple stack of 2D convolutions with increasing channel counts
        self.m = nn.ModuleList([
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=5, stride=1),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, bias=False),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, bias=False),
        ])

    def forward(self, x):
        for module in self.m:
            x = module(x)
        return x


if __name__ == "__main__":
    dtype = torch.half
    device = "cuda"
    dummy_input = torch.randn(8, 8, 512, 512, dtype=dtype, device=device)
    model = Net1().to(dtype=dtype, device=device)
    input_names = ["input1"]
    output_names = ["output1"]
    torch.onnx.export(model, dummy_input, "test.onnx", input_names=input_names, output_names=output_names)
```

I profiled the launch of `./build/RelWithDebInfo/onnxruntime_perf_test -e cuda -I -q -t 5 test.onnx` using nsys and NVTX ranges.

Current master launches the kernels below:

![image](https://github.com/microsoft/onnxruntime/assets/44298237/81655fce-0f8e-4f78-9335-b858a8c8977b)

If I add the newly introduced `-l` flag, we see the kernels below:

![image](https://github.com/microsoft/onnxruntime/assets/44298237/fceb5d6f-c12d-442b-b15a-948797630008)

Notice the missing NCHW<->NHWC conversion kernels per operation. The layout optimizer instead introduces a transpose op only as the first and last op of the whole network. The `op_generic_tensor_kernel` shows the bias being applied, which should also be optimized out next.

Measured across some very basic models:

| CUDA EP | **NCHW** [ms] (`-e cuda -t 5 -q`) | **NHWC** [ms] (`-e cuda -t 5 -q -l`) | Speedup |
|:------------------------|---------------:|---------------:|--------:|
| resnet101-v2-7_bs8_fp16 | 18.33 | 13.07 | 1.4 |
| resnet101-v2-7_bs8 | 21.8 | 12.06 | 1.81 |
| test | 102.07 | 73.62 | 1.39 |

Average speedup: 1.53
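For anyone who wants to sanity-check numbers like these without onnxruntime_perf_test, below is a rough Python sketch of the same NCHW-vs-NHWC comparison. It assumes the NHWC path is exposed as a `prefer_nhwc` CUDA provider option (see the comment above about `-i "prefer_nhwc|1"`), and the naive loop includes host-device copies on every run, so absolute numbers will differ from the table.

```
# Rough wall-clock comparison of the default (NCHW) CUDA EP vs the NHWC path
# on the test.onnx model exported above. The "prefer_nhwc" option name is assumed.
import time
import numpy as np
import onnxruntime as ort

x = np.random.randn(8, 8, 512, 512).astype(np.float16)
feed = {"input1": x}

def time_session(providers, iters=50):
    sess = ort.InferenceSession("test.onnx", providers=providers)
    for _ in range(5):          # warm-up (cuDNN algo search, allocations)
        sess.run(None, feed)
    start = time.perf_counter()
    for _ in range(iters):
        sess.run(None, feed)
    return (time.perf_counter() - start) / iters * 1000.0  # ms per run

nchw_ms = time_session(["CUDAExecutionProvider"])
nhwc_ms = time_session([("CUDAExecutionProvider", {"prefer_nhwc": "1"})])
print(f"NCHW: {nchw_ms:.2f} ms  NHWC: {nhwc_ms:.2f} ms  speedup: {nchw_ms / nhwc_ms:.2f}x")
```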
## Outlook

Next, the mission will be to write a templated unit test that checks NHWC ops for correctness against their NCHW counterparts. After that we have to transition more ops to measure perf improvements on a broader range of models. Currently this is not easily possible, as we do not yet support all ops in the NHWC domain.

---------

Co-authored-by: Tianlei Wu <[email protected]>