[js/webgpu] Optimize InstanceNormalization (#21995)
### Description

For InstanceNormalization, `y = scale * (x - mean) / sqrt(variance + epsilon) + B`, where mean and variance are computed per instance, per channel. Computing mean and variance per channel is a reduction, which is NCHW-layout friendly since adjacent threads can then access contiguous data in GPU memory.

This PR optimizes both NHWC and NCHW InstanceNormalization. To calculate the mean and variance efficiently, we first make sure the input is in NCHW layout instead of NHWC, then use shared memory to perform the reduction that produces `channel_scale` and `channel_shift`. With this PR, computing `channel_scale` and `channel_shift` is identical for NHWC and NCHW InstanceNormalization, and their overall performance is now very close.

The data below comes from SD Turbo profiling results.

Before (InstanceNormalization overall time: 140.84 ms):

| Kernel | Time (ms) |
| -- | -- |
| InstanceNormalization\|InstanceNormComputeMean | 129.70 |
| InstanceNormalization\|InstanceNormalizationNHWC | 10.55 |
| InstanceNormalization\|InstanceNormComputeChannelScaleShift | 0.59 |

After (InstanceNormalization overall time: 59.44 ms):

| Kernel | Time (ms) |
| -- | -- |
| InstanceNormalization\|InstanceNormComputeChannelScaleShift | 28.57 |
| InstanceNormalization\|TransposeShared | 20.19 |
| InstanceNormalization\|InstanceNormalizationNHWC | 10.68 |
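As a reference for the `channel_scale` / `channel_shift` factorization described above, here is a minimal TypeScript sketch (an illustration, not the actual WGSL shader code; the function name and signature are hypothetical). It folds the per-channel mean and variance into two constants, so the final normalization is a single element-wise multiply-add regardless of layout:

```typescript
// Sketch of the channel_scale / channel_shift factorization:
//   y = scale * (x - mean) / sqrt(variance + epsilon) + B
//     = x * channel_scale + channel_shift
// where channel_scale = scale / sqrt(variance + epsilon)
// and   channel_shift = B - mean * channel_scale,
// computed per instance, per channel over the spatial elements.
function instanceNormChannel(
  x: number[], // spatial elements of one (instance, channel) slice
  scale: number, // per-channel scale
  bias: number, // per-channel bias (B)
  epsilon: number,
): number[] {
  const n = x.length;
  const mean = x.reduce((a, b) => a + b, 0) / n;
  const variance = x.reduce((a, b) => a + (b - mean) * (b - mean), 0) / n;
  // In the shader, mean/variance come from a shared-memory reduction;
  // here they are computed serially for clarity.
  const channelScale = scale / Math.sqrt(variance + epsilon);
  const channelShift = bias - mean * channelScale;
  return x.map((v) => v * channelScale + channelShift);
}
```

Because the reduction output is just these two scalars per (instance, channel), the element-wise pass that consumes them is layout-agnostic, which is why the NHWC and NCHW paths can share the same `InstanceNormComputeChannelScaleShift` step.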