diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 40a92f9e0fd69..7e0594fe02662 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -49,8 +49,7 @@ const createLayerNormProgramInfo = } } - // TODO: for some reason it does not work correctly with fp16 - const components = inputs[0].dataType !== DataType.float16 ? getMaxComponents(normSize) : 1; + const components = getMaxComponents(normSize); const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const variables = [ inputVariable('x', inputs[0].dataType, inputs[0].dims, components), @@ -72,26 +71,27 @@ const createLayerNormProgramInfo = } const getShaderSource = (shaderHelper: ShaderHelper) => ` - const normSize: u32 = ${normSize / components}; + const normSize: f32 = ${normSize}; + const normSizeVectorized: u32 = ${normSize / components}; const epsilon: f32 = ${attributes.epsilon}; ${shaderHelper.declareVariables(...variables)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)} - let offset = global_idx * normSize; + let offset = global_idx * normSizeVectorized; var meanVector = ${fillVector('f32', components)}; var meanSquareVector = ${fillVector('f32', components)}; - for (var h: u32 = 0u; h < normSize; h++) { + for (var h: u32 = 0u; h < normSizeVectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; meanVector += value; meanSquareVector += value * value; } - let mean = ${sumVector('meanVector', components)} / f32(normSize); + let mean = ${sumVector('meanVector', components)} / normSize; let meanSquare = sqrt(${sumVector('meanSquareVector', components)} - / f32(normSize) - mean * mean + epsilon); + / normSize - mean * mean + epsilon); - for (var j: u32 = 0; j < normSize; j++) { + for (var j: u32 = 0; j < normSizeVectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 268ce5307eb74..464601fe415d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -109,7 +109,7 @@ const createSkipLayerNormProgramInfo = } const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const getShaderSource = (shaderHelper: ShaderHelper) => ` - const hiddenSize: u32 = ${hiddenSize}; + const hiddenSize: f32 = ${hiddenSize}; const hiddenSizeVectorized: u32 = ${hiddenSize / components}; const epsilon: f32 = ${attributes.epsilon}; @@ -131,8 +131,8 @@ const createSkipLayerNormProgramInfo = sum += f32Value; squareSum += f32Value * f32Value; } - let mean = ${sumVector('sum', components)} / f32(hiddenSize); - let variance = sqrt(${sumVector('squareSum', components)} / f32(hiddenSize) - mean * mean + epsilon); + let mean = ${sumVector('sum', components)} / hiddenSize; + let variance = sqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''} for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {