diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 201c9d4b209db..616738b0e5b1e 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -17,7 +17,7 @@ import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; -import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; +import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad, parsePadAttributes} from './ops/pad'; @@ -83,7 +83,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], ['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]], - ['LayerNormalization', [layerNorm, parseLayerNormAttributes]], + ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], ['Less', [binaryOps.less]], ['LessOrEqual', [binaryOps.lessOrEqual]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 8a9eeecf2c68d..eeaf4d2082979 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -4,12 +4,11 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; +import {castToF32, createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType,} from './common'; -export interface LayerNormAttributes extends AttributeWithCacheKey { +interface LayerNormAttributes { axis: number; epsilon: number; } @@ -51,46 +50,58 @@ const createLayerNormProgramInfo = const components = getMaxComponents(normSize); const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('scale', scale.dataType, scale.dims, components), + inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components), + inputVariable('scale', scale.dataType, scale.dims.length, components), ]; + + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: normCount}, {type: 'float32', data: normSize}, + {type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon} + ]; + const uniforms: UniformsArrayType = [ + {name: 'normCount', type: 'u32'}, {name: 'normSize', type: 'f32'}, {name: 'normSizeVectorized', type: 'u32'}, + {name: 'epsilon', type: 'f32'} + ]; + programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(scale.dims)); if (bias) { - variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + variables.push(inputVariable('bias', bias.dataType, bias.dims.length, components)); + programUniforms.push(...createTensorShapeVariables(bias.dims)); + inputDependencies.push('rank'); } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + variables.push(outputVariable('output', inputs[0].dataType, outputShape.length, components)); + programUniforms.push(...createTensorShapeVariables(outputShape)); const hasMeanDataOutput = outputCount > 1; const hasInvStdOutput = outputCount > 2; if (hasMeanDataOutput) { - variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim)); + variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim.length)); + programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim)); } if (hasInvStdOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); + variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim.length)); + programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim)); } const getShaderSource = (shaderHelper: ShaderHelper) => ` - const normSize: f32 = ${normSize}; - const normSizeVectorized: u32 = ${normSize / components}; - const epsilon: f32 = ${attributes.epsilon}; - - ${shaderHelper.declareVariables(...variables)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)} - let offset = global_idx * normSizeVectorized; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.normCount')} + let offset = global_idx * uniforms.normSizeVectorized; var meanVector = ${fillVector('f32', components)}; var meanSquareVector = ${fillVector('f32', components)}; - for (var h: u32 = 0u; h < normSizeVectorized; h++) { + for (var h: u32 = 0u; h < uniforms.normSizeVectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; meanVector += value; meanSquareVector += value * value; } - let mean = ${sumVector('meanVector', components)} / normSize; - let meanSquare = sqrt(${sumVector('meanSquareVector', components)} - / normSize - mean * mean + epsilon); + let mean = ${sumVector('meanVector', components)} / uniforms.normSize; + let meanSquare = sqrt(${sumVector('meanSquareVector', components)} + / uniforms.normSize - mean * mean + uniforms.epsilon); - for (var j: u32 = 0; j < normSizeVectorized; j++) { + for (var j: u32 = 0; j < uniforms.normSizeVectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale @@ -111,15 +122,13 @@ const createLayerNormProgramInfo = return { name: 'LayerNormalization', - shaderCache: {hint: `${attributes.cacheKey}|${outputCount}|${inputs.length}`}, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}}), + shaderCache: {hint: `${components};${hasMeanDataOutput};${hasInvStdOutput}`, inputDependencies}, + getRunData: () => + ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}, programUniforms}), getShaderSource, }; }; -export const parseLayerNormAttributes = (attributes: LayerNormAttributes): LayerNormAttributes => - createAttributeWithCacheKey({axis: attributes.axis, epsilon: attributes.epsilon}); - export const layerNorm = (context: ComputeContext, attributes: LayerNormAttributes): void => { validateInputs(context.inputs); context.compute(createLayerNormProgramInfo(context.inputs, attributes, context.outputCount));