diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index ec9da2613f406..00a6ca75b34fa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -108,7 +108,7 @@ const createBatchNormInferenceProgramInfo = let inputMean = ${inputMean.getByOffset('cOffset')}; let inputVar = ${inputVar.getByOffset('cOffset')}; let x = ${x.getByOffset('global_idx')}; - let value = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias; + let value = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias; ${y.setByOffset('global_idx', 'value')} }`; return { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 056dd54d54591..a835c90bd5451 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -92,7 +92,7 @@ const createInstanceNormProgramInfo = } workgroupBarrier(); - let invStdDev = 1 / sqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon})); + let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon})); let channelScale = invStdDev * f32(${scale.getByOffset('channel')}); let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale; for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { @@ -212,7 +212,7 @@ const computeMean = } sum = sum / f32(uniforms.H); squaredSum = squaredSum / f32(uniforms.H); - let invStdDev = 1 / sqrt(squaredSum - sum * sum + f32(${epsilon})); + let invStdDev = inverseSqrt(squaredSum - sum * sum + f32(${epsilon})); let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]); let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index bc446079faf8f..3c9f6ce71bb67 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -93,19 +93,19 @@ const createLayerNormProgramInfo = meanSquareVector += value * value; } let mean = ${sumVector('meanVector', components)} / uniforms.norm_size; - let meanSquare = sqrt(${sumVector('meanSquareVector', components)} - / uniforms.norm_size - mean * mean + uniforms.epsilon); + let invStdDev = + inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; - output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale + output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} ); } ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'inv_std_output[global_idx] = 1 / meanSquare' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''}; }`; }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 7e500f865c19b..a2fda9f07d09f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -132,11 +132,11 @@ const createSkipLayerNormProgramInfo = squareSum += f32Value * f32Value; } let mean = ${sumVector('sum', components)} / hiddenSize; - let variance = sqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); + let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} - ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''} + ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''} for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - output[offset + i] = (output[offset + i] - ${dataType}(mean)) / ${dataType}(variance) * gamma[i] + output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i] + ${hasBetaInput ? 'beta[i]' : '0.0'}; } }`;