[js/webgpu] Optimize NCHW layout for InstanceNormalization (#18123)
### Description

The changes in this PR include:

1) Fix f16 errors in InstanceNormalization with the NCHW format.
2) Use vec to further optimize the original algorithm.
3) (Removed) Don't do layout conversion for InstanceNormalization for JSEP, since InstanceNormalization itself is well suited to the NCHW layout and has better performance in our current implementation.

Tested on sd-vae-decoder-f16.onnx, inference time drops from 314 ms to 285 ms. The aggregate GPU profiling data is below (note: the data is based on change 3):

Before:

| Kernel | Time (ms) | Percentage (%) |
| -- | -- | -- |
| Conv | 201.55 | 69.56 |
| InstanceNormalization | 42.49 | 14.67 |
| Transpose | 28.95 | 9.99 |
| Mul | 5.69 | 1.96 |
| Add | 3.82 | 1.32 |
| MatMul | 3.27 | 1.13 |
| Sigmoid | 2.24 | 0.77 |
| Resize | 1.16 | 0.40 |
| Softmax | 0.34 | 0.12 |
| Cast | 0.24 | 0.08 |
| Sum | 289.75 | |

After:

| Kernel | Time (ms) | Percentage (%) |
| -- | -- | -- |
| Conv | 205.44 | 79.43 |
| InstanceNormalization | 18.24 | 7.05 |
| Transpose | 17.64 | 6.82 |
| Mul | 5.69 | 2.20 |
| Add | 3.81 | 1.47 |
| MatMul | 3.56 | 1.38 |
| Sigmoid | 2.24 | 0.86 |
| Resize | 1.19 | 0.46 |
| Softmax | 0.59 | 0.23 |
| Cast | 0.24 | 0.09 |
| Sum | 258.65 | |

From the tables above, we can see that the time of two ops is greatly reduced: InstanceNormalization and Transpose. The transpose time is reduced because each InstanceNormalization is surrounded by two reshape ops in sd-vae-decoder-f16.onnx.
Because JSEP prefers NHWC and InstanceNormalization is a layout-sensitive op, two extra transpose ops were inserted dynamically when executing this model. After this change, those inserted transpose ops are no longer needed, so the overall transpose time is reduced.
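To illustrate why NCHW is a natural fit for InstanceNormalization, here is a minimal CPU reference sketch in TypeScript (not the actual WGSL shader from this PR; function and parameter names are illustrative). In NCHW, the H*W elements of each (batch, channel) slice are contiguous, so the mean/variance pass reads memory sequentially, and a shader can vectorize that pass, e.g. loading 4 elements at a time:

```typescript
// Hypothetical reference implementation of InstanceNormalization over an
// NCHW tensor stored as a flat Float32Array. For each (n, c) pair, the
// H*W spatial slice is normalized by its own mean and variance, then
// scaled and shifted by the per-channel scale/bias.
function instanceNormNCHW(
  input: Float32Array,
  n: number, c: number, h: number, w: number,
  scale: Float32Array, bias: Float32Array,
  epsilon = 1e-5,
): Float32Array {
  const spatial = h * w;
  const out = new Float32Array(input.length);
  for (let ni = 0; ni < n; ni++) {
    for (let ci = 0; ci < c; ci++) {
      const base = (ni * c + ci) * spatial;
      // One contiguous pass accumulating sum and sum-of-squares; this is
      // the loop a GPU implementation can vectorize (vec4 loads) in NCHW.
      let sum = 0;
      let sqSum = 0;
      for (let i = 0; i < spatial; i++) {
        const v = input[base + i];
        sum += v;
        sqSum += v * v;
      }
      const mean = sum / spatial;
      const variance = sqSum / spatial - mean * mean;
      const invStd = 1 / Math.sqrt(variance + epsilon);
      for (let i = 0; i < spatial; i++) {
        out[base + i] = (input[base + i] - mean) * invStd * scale[ci] + bias[ci];
      }
    }
  }
  return out;
}
```

In NHWC, by contrast, the elements of one channel are strided by C, so JSEP would otherwise transpose around the op; keeping the op in NCHW avoids those inserted transposes, which is what the reduced Transpose time in the table reflects.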