Skip to content

Commit

Permalink
[js/webgpu] Optimize InstanceNorm in some shapes (microsoft#22637)
Browse files Browse the repository at this point in the history
BUG microsoft#22031

Optimize the following two situations:
1. Increase workgroupSize when only one workgroup is dispatched.
2. Avoid the transpose when it is not necessary.

With this PR and PR microsoft#22577, the overall time of the demucs model
drops from 154.60 ms to 106.36 ms on my dGPUs.
  • Loading branch information
qjia7 authored and Ishwar Raut committed Nov 19, 2024
1 parent 927b6b7 commit 9137e2c
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ const computeChannelScaleShift = (
const f32Type = components === 1 ? 'f32' : `vec${components}f`;
const wgType = components === 1 ? 'vec2f' : `mat2x${components}f`;
const unitsOfWork = n * c;

let workgroupSize = 64;
if (unitsOfWork === 1) {
workgroupSize = 256;
}
const inputShape = [n, c, h / components];
const outputShape = [n, c, 2];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
Expand All @@ -49,7 +52,6 @@ const computeChannelScaleShift = (
const b = inputVariable('bias', bias.dataType, bias.dims);
const output = outputVariable('output', DataType.float, 3, 2);
const variables = [x, s, b, output];
const workgroupSize = 64;
return `
var<workgroup> workgroup_shared : array<${wgType}, ${workgroupSize}>;
const workgroup_size = ${workgroupSize}u;
Expand Down Expand Up @@ -91,7 +93,7 @@ const computeChannelScaleShift = (
{
name: 'InstanceNormComputeChannelScaleShift',
// TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
shaderCache: { hint: `${components};${epsilon}`, inputDependencies },
shaderCache: { hint: `${components};${epsilon};${workgroupSize}`, inputDependencies },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: DataType.float }],
dispatchGroup: { x: unitsOfWork },
Expand Down Expand Up @@ -187,14 +189,21 @@ const createInstanceNormNHWCProgramInfo = (
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];

// 1. transpose x from NHWC to NCHW
let needTranspose = false;
const transposedXPerm = [0, xShape.length - 1];
for (let i = 0; i < xShape.length - 2; i++) {
needTranspose = needTranspose || xShape[i + 1] !== 1;
transposedXPerm.push(i + 1);
}
const transposedX = context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), {
inputs: [context.inputs[0]],
outputs: [-1],
})[0];

needTranspose = needTranspose && xShape[xShape.length - 1] !== 1;

const transposedX = needTranspose
? context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), {
inputs: [context.inputs[0]],
outputs: [-1],
})[0]
: context.inputs[0].reshape(Array.from({ length: xShape.length }, (_, i) => xShape[transposedXPerm[i]]));
// 2. compute channel scale and channel shift.
const channelScaleShift = computeChannelScaleShift(
context,
Expand Down

0 comments on commit 9137e2c

Please sign in to comment.