Skip to content

Commit

Permalink
[js/webgpu] Support uniforms for layer-norm
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Dec 26, 2023
1 parent 37f7436 commit b763ed5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 35 deletions.
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
import {layerNorm} from './ops/layer-norm';
import {matMul} from './ops/matmul';
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
import {pad, parsePadAttributes} from './ops/pad';
Expand Down Expand Up @@ -83,7 +83,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
['LayerNormalization', [layerNorm, parseLayerNormAttributes]],
['LayerNormalization', [layerNorm]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
['Less', [binaryOps.less]],
['LessOrEqual', [binaryOps.lessOrEqual]],
Expand Down
81 changes: 48 additions & 33 deletions js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common';
import {castToF32, createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType,} from './common';

export interface LayerNormAttributes extends AttributeWithCacheKey {
interface LayerNormAttributes {
axis: number;
epsilon: number;
}
Expand Down Expand Up @@ -47,60 +46,78 @@ const createLayerNormProgramInfo =
meanInvStdDevDim.push(1);
}
}

const components = getMaxComponents(normSize);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const variables = [
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
inputVariable('scale', scale.dataType, scale.dims, components),
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: normCount}, {type: 'float32', data: normSize},
{type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon}
];
programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(scale.dims));
if (bias) {
variables.push(inputVariable('bias', bias.dataType, bias.dims, components));
programUniforms.push(...createTensorShapeVariables(bias.dims));
inputDependencies.push('rank');
}
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
programUniforms.push(...createTensorShapeVariables(outputShape));

const hasMeanDataOutput = outputCount > 1;
const hasInvStdOutput = outputCount > 2;

if (hasMeanDataOutput) {
variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim));
programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim));
}
if (hasInvStdOutput) {
variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim));
}

const getShaderSource = (shaderHelper: ShaderHelper) => `
const normSize: f32 = ${normSize};
const normSizeVectorized: u32 = ${normSize / components};
const epsilon: f32 = ${attributes.epsilon};
const getShaderSource = (shaderHelper: ShaderHelper) => {
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const variables = [
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components),
inputVariable('scale', scale.dataType, scale.dims.length, components),
];
if (bias) {
variables.push(inputVariable('bias', bias.dataType, bias.dims.length, components));
}
variables.push(outputVariable('output', inputs[0].dataType, outputShape.length, components));
if (hasMeanDataOutput) {
variables.push(outputVariable('mean_data_output', DataType.float, meanInvStdDevDim.length));
}
if (hasInvStdOutput) {
variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim.length));
}

${shaderHelper.declareVariables(...variables)}
const uniforms: UniformsArrayType = [
{name: 'norm_count', type: 'u32'}, {name: 'norm_size', type: 'f32'}, {name: 'norm_size_vectorized', type: 'u32'},
{name: 'epsilon', type: 'f32'}
];
return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)}
let offset = global_idx * normSizeVectorized;
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')}
let offset = global_idx * uniforms.norm_size_vectorized;
var meanVector = ${fillVector('f32', components)};
var meanSquareVector = ${fillVector('f32', components)};
for (var h: u32 = 0u; h < normSizeVectorized; h++) {
for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {
let value = ${castToF32(dataType, components, 'x[h + offset]')};
meanVector += value;
meanSquareVector += value * value;
}
let mean = ${sumVector('meanVector', components)} / normSize;
let meanSquare = sqrt(${sumVector('meanSquareVector', components)}
/ normSize - mean * mean + epsilon);
let mean = ${sumVector('meanVector', components)} / uniforms.norm_size;
let meanSquare = sqrt(${sumVector('meanSquareVector', components)}
/ uniforms.norm_size - mean * mean + uniforms.epsilon);
for (var j: u32 = 0; j < normSizeVectorized; j++) {
for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {
let f32input = ${castToF32(dataType, components, 'x[j + offset]')};
let f32scale = ${castToF32(dataType, components, 'scale[j]')};
output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale
${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''}
);
}
${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''};
${hasInvStdOutput ? 'invStdOutput[global_idx] = 1 / meanSquare' : ''};
${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''};
${hasInvStdOutput ? 'inv_std_output[global_idx] = 1 / meanSquare' : ''};
}`;
};
const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
if (hasMeanDataOutput) {
outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
Expand All @@ -111,15 +128,13 @@ const createLayerNormProgramInfo =

return {
name: 'LayerNormalization',
shaderCache: {hint: `${attributes.cacheKey}|${outputCount}|${inputs.length}`},
getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}}),
shaderCache: {hint: `${components};${hasMeanDataOutput};${hasInvStdOutput}`, inputDependencies},
getRunData: () =>
({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}, programUniforms}),
getShaderSource,
};
};

export const parseLayerNormAttributes = (attributes: LayerNormAttributes): LayerNormAttributes =>
createAttributeWithCacheKey({axis: attributes.axis, epsilon: attributes.epsilon});

export const layerNorm = (context: ComputeContext, attributes: LayerNormAttributes): void => {
validateInputs(context.inputs);
context.compute(createLayerNormProgramInfo(context.inputs, attributes, context.outputCount));
Expand Down

0 comments on commit b763ed5

Please sign in to comment.