Skip to content

Commit

Permalink
[js/webgpu] Optimize InstanceNorm in some shapes (microsoft#22637)
Browse files Browse the repository at this point in the history
BUG microsoft#22031

Optimize below two situations:
1. Increase workgroupSize if only one workgroup is dispatched.
2. Avoid transpose if not necessary.

The overall time of demucs model becomes 106.36 ms from 154.60 ms on my
dGPUs with this PR and PR microsoft#22577
  • Loading branch information
qjia7 authored and ankitm3k committed Dec 11, 2024
1 parent f29b77f commit fd8e811
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ const computeChannelScaleShift = (
const f32Type = components === 1 ? 'f32' : `vec${components}f`;
const wgType = components === 1 ? 'vec2f' : `mat2x${components}f`;
const unitsOfWork = n * c;

let workgroupSize = 64;
if (unitsOfWork === 1) {
workgroupSize = 256;
}
const inputShape = [n, c, h / components];
const outputShape = [n, c, 2];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
Expand All @@ -49,7 +52,6 @@ const computeChannelScaleShift = (
const b = inputVariable('bias', bias.dataType, bias.dims);
const output = outputVariable('output', DataType.float, 3, 2);
const variables = [x, s, b, output];
const workgroupSize = 64;
return `
var<workgroup> workgroup_shared : array<${wgType}, ${workgroupSize}>;
const workgroup_size = ${workgroupSize}u;
Expand Down Expand Up @@ -91,7 +93,7 @@ const computeChannelScaleShift = (
{
name: 'InstanceNormComputeChannelScaleShift',
// TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
shaderCache: { hint: `${components};${epsilon}`, inputDependencies },
shaderCache: { hint: `${components};${epsilon};${workgroupSize}`, inputDependencies },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: DataType.float }],
dispatchGroup: { x: unitsOfWork },
Expand Down Expand Up @@ -187,14 +189,21 @@ const createInstanceNormNHWCProgramInfo = (
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];

// 1. transpose x from NHWC to NCHW
let needTranspose = false;
const transposedXPerm = [0, xShape.length - 1];
for (let i = 0; i < xShape.length - 2; i++) {
needTranspose = needTranspose || xShape[i + 1] !== 1;
transposedXPerm.push(i + 1);
}
const transposedX = context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), {
inputs: [context.inputs[0]],
outputs: [-1],
})[0];

needTranspose = needTranspose && xShape[xShape.length - 1] !== 1;

const transposedX = needTranspose
? context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), {
inputs: [context.inputs[0]],
outputs: [-1],
})[0]
: context.inputs[0].reshape(Array.from({ length: xShape.length }, (_, i) => xShape[transposedXPerm[i]]));
// 2. compute channel scale and channel shift.
const channelScaleShift = computeChannelScaleShift(
context,
Expand Down

0 comments on commit fd8e811

Please sign in to comment.