Skip to content

Commit

Permalink
[js/webgpu] Change A/sqrt(B) to A*inverseSqrt(B) in normalization ops (
Browse files Browse the repository at this point in the history
…#19101)

### Description
Change `A / sqrt(B)` to `A * inverseSqrt(B)` in BatchNormalization,
InstanceNormalization, LayerNormalization and SkipLayerNormalization.

### Motivation and Context
For the same reason as the existence of the `inverseSqrt` built-in in
WebGPU spec.
  • Loading branch information
hujiajie authored Jan 12, 2024
1 parent 5373c0c commit acba63c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ const createBatchNormInferenceProgramInfo =
let inputMean = ${inputMean.getByOffset('cOffset')};
let inputVar = ${inputVar.getByOffset('cOffset')};
let x = ${x.getByOffset('global_idx')};
let value = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias;
let value = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias;
${y.setByOffset('global_idx', 'value')}
}`;
return {
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ const createInstanceNormProgramInfo =
}
workgroupBarrier();
let invStdDev = 1 / sqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon}));
let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon}));
let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
Expand Down Expand Up @@ -212,7 +212,7 @@ const computeMean =
}
sum = sum / f32(uniforms.H);
squaredSum = squaredSum / f32(uniforms.H);
let invStdDev = 1 / sqrt(squaredSum - sum * sum + f32(${epsilon}));
let invStdDev = inverseSqrt(squaredSum - sum * sum + f32(${epsilon}));
let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;
Expand Down
8 changes: 4 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,19 @@ const createLayerNormProgramInfo =
meanSquareVector += value * value;
}
let mean = ${sumVector('meanVector', components)} / uniforms.norm_size;
let meanSquare = sqrt(${sumVector('meanSquareVector', components)}
/ uniforms.norm_size - mean * mean + uniforms.epsilon);
let invStdDev =
inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon);
for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {
let f32input = ${castToF32(dataType, components, 'x[j + offset]')};
let f32scale = ${castToF32(dataType, components, 'scale[j]')};
output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale
output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale
${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''}
);
}
${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''};
${hasInvStdOutput ? 'inv_std_output[global_idx] = 1 / meanSquare' : ''};
${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''};
}`;
};
const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ const createSkipLayerNormProgramInfo =
squareSum += f32Value * f32Value;
}
let mean = ${sumVector('sum', components)} / hiddenSize;
let variance = sqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon);
let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon);
${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''}
${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''}
${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''}
for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
output[offset + i] = (output[offset + i] - ${dataType}(mean)) / ${dataType}(variance) * gamma[i]
output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i]
+ ${hasBetaInput ? 'beta[i]' : '0.0'};
}
}`;
Expand Down

0 comments on commit acba63c

Please sign in to comment.