[js/webgpu] Optimize InstanceNormalization (#21995)
### Description
InstanceNormalization computes `y = scale * (x - mean) /
sqrt(variance + epsilon) + B`, where mean and variance are computed per
instance per channel. Calculating mean and variance per channel is a
reduce operation, which is NCHW-layout friendly since it lets
adjacent threads access contiguous data in GPU memory.
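As a plain reference for the reduce described above (illustrative TypeScript, not the WGSL shader in this PR), the per-(instance, channel) mean and variance over an NCHW buffer can be sketched as follows; note that in NCHW layout the `H * W` elements of one channel occupy one contiguous span:

```typescript
// Reference sketch: per-(instance, channel) mean and variance
// over a flattened NCHW buffer of length n * c * spatial.
function channelMeanVariance(
  data: Float32Array, // flattened NCHW input
  n: number,
  c: number,
  spatial: number, // H * W
): { mean: Float32Array; variance: Float32Array } {
  const mean = new Float32Array(n * c);
  const variance = new Float32Array(n * c);
  for (let i = 0; i < n * c; i++) {
    let sum = 0;
    let sqSum = 0;
    const base = i * spatial; // contiguous span for this (instance, channel)
    for (let j = 0; j < spatial; j++) {
      const v = data[base + j];
      sum += v;
      sqSum += v * v;
    }
    const m = sum / spatial;
    mean[i] = m;
    variance[i] = sqSum / spatial - m * m; // E[x^2] - E[x]^2
  }
  return { mean, variance };
}
```

On the GPU this sequential inner loop is replaced by a parallel shared-memory reduction, but the contiguity argument is the same: consecutive lanes read consecutive addresses.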

This PR optimizes both NHWC and NCHW InstanceNormalization. To
calculate the mean and variance efficiently, we first make sure the
input is in NCHW rather than NHWC layout, then use shared memory for the
reduce operation that produces `channel_scale` and `channel_shift`.

With this PR, computing `channel_scale` and `channel_shift` is identical for
NHWC and NCHW InstanceNormalization, and their overall performance is now
very close.
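The algebra behind `channel_scale` and `channel_shift` can be sketched as follows (illustrative TypeScript, not the PR's WGSL): the normalization formula is rewritten so that each element only needs one fused multiply-add after the per-channel factors are precomputed.

```typescript
// y = scale * (x - mean) / sqrt(variance + epsilon) + B
// can be rewritten per element as:
//   y = x * channelScale + channelShift
// with the per-channel factors:
//   channelScale = scale / sqrt(variance + epsilon)
//   channelShift = B - mean * channelScale
function channelScaleShift(
  scale: number,
  bias: number, // B
  mean: number,
  variance: number,
  epsilon: number,
): { channelScale: number; channelShift: number } {
  const channelScale = scale / Math.sqrt(variance + epsilon);
  const channelShift = bias - mean * channelScale;
  return { channelScale, channelShift };
}
```

Precomputing these two factors once per (instance, channel) is what lets the NHWC and NCHW element-wise passes share the same cheap per-element work.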

The data below comes from SD Turbo profiling results.
Before (InstanceNormalization overall time: 140.84 ms):

| Kernel | Time (ms) |
| -- | -- |
| InstanceNormalization\|InstanceNormComputeMean | 129.70 |
| InstanceNormalization\|InstanceNormalizationNHWC | 10.55 |
| InstanceNormalization\|InstanceNormComputeChannelScaleShift | 0.59 |


After (InstanceNormalization overall time: 59.44 ms):

| Kernel | Time (ms) |
| -- | -- |
| InstanceNormalization\|InstanceNormComputeChannelScaleShift | 28.57 |
| InstanceNormalization\|TransposeShared | 20.19 |
| InstanceNormalization\|InstanceNormalizationNHWC | 10.68 |
qjia7 authored Sep 23, 2024
1 parent 9b37b3e commit 80e9df8
Showing 1 changed file with 154 additions and 215 deletions.
