forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[js/webgpu] Optimize InstanceNormalization (microsoft#17491)
### Description <!-- Describe your changes. --> In previous implementation, there are two loops to iterate H * W elements to calculate the `mean` and `squaredNorm` value in one thread, meanwhile it outputs H * W elements in one thread. That results it's very very slow when H * W is a large value. And usually, H * W does be a large value in a model. For example, in the `candy-8` model, the shapes of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization` op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below: ``` [profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns ``` In this PR, it uses workgroup memory to optimize the original algorithm. The advantage is that it can parallelly utilize the 64 (workgroupSize) threads in one workgroup to calculate `mean` and `squaredNorm` value. Meanwhile, it only outputs `H * W / workgroupSize` outputs for one thread, which greatly reduces the overhead for one thread. With this optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the extra two `transpose`. The `createInstanceNormProgramInfo` only needs `0.64` ms. See below: ``` [profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns program-manager.ts:115 [profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns program-manager.ts:115 [profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns ``` This PR currently only applies the new algorithm to NCHW format. For NHWC format, one way is to transpose the input so that it can use the new algorithm. But the disadvantage is that 2 extra transpose are added. @dakenf also gives another way to optimize NHWC. Details see [here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts). I checked @dakenf's method. The perf is similar with transpose + optimized NCHW. But on different GPUs, one is a little better than another or vice versa. So I prefer this PR only does the NCHW part. @dakenf can submit his optimization on NHWC.
- Loading branch information
Showing
4 changed files
with
147 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
[ | ||
{ | ||
"name": "Simple test with NHWC", | ||
"operator": "InstanceNormalization", | ||
"inputShapeDefinitions": "rankOnly", | ||
"opset": { "domain": "", "version": 17 }, | ||
"cases": [ | ||
{ | ||
"name": "Simple test", | ||
"inputs": [ | ||
{ | ||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4], | ||
"dims": [1, 4, 2, 2], | ||
"type": "float32" | ||
}, | ||
{ | ||
"data": [1, 2, 3, 4], | ||
"dims": [4], | ||
"type": "float32" | ||
}, | ||
{ | ||
"data": [4, 5, 6, 7], | ||
"dims": [4], | ||
"type": "float32" | ||
} | ||
], | ||
"outputs": [ | ||
{ | ||
"data": [ | ||
2.6583645343780518, 3.552788257598877, 4.447211742401123, 5.341635704040527, 2.3167295455932617, | ||
4.105576515197754, 5.8944244384765625, 7.683271408081055, 6, 10.242595672607422, 6, 1.7574005126953125, | ||
12.36654281616211, 8.788846969604492, 5.211153030395508, 1.633458137512207 | ||
], | ||
"dims": [1, 4, 2, 2], | ||
"type": "float32" | ||
} | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"name": "Simple test with NCHW", | ||
"operator": "InstanceNormalization", | ||
"opset": { "domain": "", "version": 17 }, | ||
"cases": [ | ||
{ | ||
"name": "Simple test", | ||
"inputs": [ | ||
{ | ||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4], | ||
"dims": [1, 4, 2, 2], | ||
"type": "float32" | ||
}, | ||
{ | ||
"data": [1, 2, 3, 4], | ||
"dims": [4], | ||
"type": "float32" | ||
}, | ||
{ | ||
"data": [4, 5, 6, 7], | ||
"dims": [4], | ||
"type": "float32" | ||
} | ||
], | ||
"outputs": [ | ||
{ | ||
"data": [ | ||
2.6583645343780518, 3.552788257598877, 4.447211742401123, 5.341635704040527, 2.3167295455932617, | ||
4.105576515197754, 5.8944244384765625, 7.683271408081055, 6, 10.242595672607422, 6, 1.7574005126953125, | ||
12.36654281616211, 8.788846969604492, 5.211153030395508, 1.633458137512207 | ||
], | ||
"dims": [1, 4, 2, 2], | ||
"type": "float32" | ||
} | ||
] | ||
} | ||
] | ||
} | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters