From 8f024b739439c521cd19fe5a9830d4015de99bd7 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 9 Jan 2024 10:16:25 +0800 Subject: [PATCH] [js/webgpu] Support uniforms for layer-norm (#18755) --- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 83 ++++++++++--------- 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 8e1ec782079be..06c3c6c196501 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -17,7 +17,7 @@ import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; -import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; +import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad, parsePadAttributes} from './ops/pad'; @@ -83,7 +83,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], ['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]], - ['LayerNormalization', [layerNorm, parseLayerNormAttributes]], + ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], ['Less', [binaryOps.less]], ['LessOrEqual', [binaryOps.lessOrEqual]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 8a9eeecf2c68d..bc446079faf8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -4,12 +4,11 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType,} from './common'; -export interface LayerNormAttributes extends AttributeWithCacheKey { +interface LayerNormAttributes { axis: number; epsilon: number; } @@ -39,7 +38,7 @@ const createLayerNormProgramInfo = Got scale size of ${scaleSize} and bias size of ${biasSize}`); } - const meanInvStdDevDim = []; + const meanInvStdDevDim: number[] = []; for (let i = 0; i < xShape.length; ++i) { if (i < axis) { meanInvStdDevDim.push(xShape[i]); @@ -47,50 +46,57 @@ const createLayerNormProgramInfo = meanInvStdDevDim.push(1); } } - const components = getMaxComponents(normSize); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('scale', scale.dataType, scale.dims, components), + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: normCount}, {type: 'float32', data: normSize}, + {type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon} ]; if (bias) { - variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + inputDependencies.push('type'); } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - const hasMeanDataOutput = outputCount > 1; const hasInvStdOutput = outputCount > 2; - if (hasMeanDataOutput) { - variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const normSize: f32 = ${normSize}; - const normSizeVectorized: u32 = ${normSize / components}; - const epsilon: f32 = ${attributes.epsilon}; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('scale', scale.dataType, scale.dims, components), + ]; + if (bias) { + variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanDataOutput) { + variables.push(outputVariable('mean_data_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } - ${shaderHelper.declareVariables(...variables)} + const uniforms: UniformsArrayType = [ + {name: 'norm_count', type: 'u32'}, {name: 'norm_size', type: 'f32'}, + {name: 'norm_size_vectorized', type: 'u32'}, {name: 'epsilon', type: 'f32'} + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)} - let offset = global_idx * normSizeVectorized; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')} + let offset = global_idx * uniforms.norm_size_vectorized; var meanVector = ${fillVector('f32', components)}; var meanSquareVector = ${fillVector('f32', components)}; - for (var h: u32 = 0u; h < normSizeVectorized; h++) { + for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; meanVector += value; meanSquareVector += value * value; } - let mean = ${sumVector('meanVector', components)} / normSize; - let meanSquare = sqrt(${sumVector('meanSquareVector', components)} - / normSize - mean * mean + epsilon); + let mean = ${sumVector('meanVector', components)} / uniforms.norm_size; + let meanSquare = sqrt(${sumVector('meanSquareVector', components)} + / uniforms.norm_size - mean * mean + uniforms.epsilon); - for (var j: u32 = 0; j < normSizeVectorized; j++) { + for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale @@ -98,9 +104,10 @@ const createLayerNormProgramInfo = ); } - ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'invStdOutput[global_idx] = 1 / meanSquare' : ''}; + ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = 1 / meanSquare' : ''}; }`; + }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (hasMeanDataOutput) { outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); @@ -111,15 +118,13 @@ const createLayerNormProgramInfo = return { name: 'LayerNormalization', - shaderCache: {hint: `${attributes.cacheKey}|${outputCount}|${inputs.length}`}, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}}), + shaderCache: {hint: `${components};${outputCount}`, inputDependencies}, + getRunData: () => + ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}, programUniforms}), getShaderSource, }; }; -export const parseLayerNormAttributes = (attributes: LayerNormAttributes): LayerNormAttributes => - createAttributeWithCacheKey({axis: attributes.axis, epsilon: attributes.epsilon}); - export const layerNorm = (context: ComputeContext, attributes: LayerNormAttributes): void => { validateInputs(context.inputs); context.compute(createLayerNormProgramInfo(context.inputs, attributes, context.outputCount));