Skip to content

Commit

Permalink
FP16 LayerNorm, InstanceNorm, SkipLayerNorm
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Sep 20, 2023
1 parent e6301ee commit f6e7335
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 137 deletions.
46 changes: 46 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,52 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 =
return typeof mappedType === 'string' ? mappedType : mappedType[1];
};

/**
* A helper function to get maximum vector size for specified data length
* @param size
*/
export const getMaxComponents = (size: number) => {
// we cannot use vec3 type since it has alignment of 16 bytes
if (size % 4 === 0) {
return 4;
} else if (size % 2 === 0) {
return 2;
}

return 1;
};

/**
* A helper function that initializes variable as a scalar or vector. e.g. f32(0) or vec4f(0,0,0,0)
* @param dataType
* @param components
* @param value
*/
export const fillVector = (dataType = 'f32', components?: number, value = '0') => {
if (!components || components === 1) {
return `${dataType}(${value})`;
}

return `vec${components}<${dataType}>(${new Array(components).fill(value).join(',')})`;
};

/**
* A helper function that returns scalar or sums all components of a vector
* @param name
* @param components
*/
export const sumVector = (name: string, components: number) => {
if (components === 4) {
return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`;
} else if (components === 2) {
return `(${name}.x + ${name}.y)`;
} else if (components === 3) {
return `(${name}.x + ${name}.y + ${name}.z)`;
}

return name;
};

/**
* A helper function to get a IndicesHelper for a given input or output.
*
Expand Down
185 changes: 140 additions & 45 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';

import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';

export interface InstanceNormAttributes extends AttributeWithCacheKey {
epsilon: number;
Expand Down Expand Up @@ -104,66 +105,160 @@ const createInstanceNormProgramInfo =
};
};

const computeMean =
(context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number,
epsilon: number) => {
const components = getMaxComponents(c);
const inputHelper = inputVariable('input', input.dataType, input.dims, components);
const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);

const WG = 64;
// we will store channel scale and channel shift in [2, components] matrix
// or in vec2 when components == 1
const outputType = components === 1 ? `vec2f` : `mat2x${components}f`;
const sumCastType = components === 1 ? `f32` : `vec${components}f`;
const setOutputValue = (var1: string, var2: string) => {
return `${outputType}(${var1}, ${var2})`;
};
const unitsOfWork = n * c / components;
const wgSize = Math.ceil(h / WG);

const getMeanShaderSource = (shaderHelper: ShaderHelper) => `
const H: u32 = ${h};
const C: u32 = ${c / components};
const imageSize: u32 = ${h * c / components};
${shaderHelper.declareVariables(inputHelper)}
@group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;
${shaderHelper.mainStart(WG)}
let currentImageNumber = global_idx / ${WG} / C;
let currentChannelNumber = (global_idx / ${WG}) % C;
let wgId = global_idx % ${WG};
let wgOffset = wgId * ${wgSize};
if (wgOffset >= H) {
return;
}
let wgMax = min(wgOffset + ${wgSize}, H);
let offset = currentImageNumber * imageSize + currentChannelNumber;
var sum = ${fillVector('f32', components)};
var squaredSum = ${fillVector('f32', components)};
for (var i: u32 = wgOffset; i < wgMax; i++) {
let value = ${sumCastType}(input[offset + i * C]);
sum += value;
squaredSum += value * value;
}
output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
}`;

const meanValues = context.compute(
{
name: 'InstanceNormComputeMean',
inputTypes: [GpuDataType.default],
cacheHint: JSON.stringify({components, n, h, c}),
outputs: [
{dims: [n, c, WG, 2], dataType: DataType.float, gpuDataType: GpuDataType.default},
],
getShaderSource: getMeanShaderSource,
dispatchGroup: () => ({x: n * c / components})
},
{inputs: [input], outputs: [-1]})[0];
const getShaderSource = (shaderHelper: ShaderHelper) => `
const H: u32 = ${h};
const C: u32 = ${c / components};
const imageSize: u32 = ${WG * c / components};
const epsilon: f32 = ${epsilon};
@group(0) @binding(0) var<storage, read> input : array<${outputType}>;
@group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
@group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
@group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)}
let currentImageNumber = global_idx / C;
let currentChannelNumber = global_idx % C;
let offset = currentImageNumber * imageSize;
var sum = ${fillVector('f32', components)};
var squaredSum = ${fillVector('f32', components)};
for (var i: u32 = 0; i < ${WG}; i++) {
let value = input[offset + i + currentChannelNumber * ${WG}];
sum += value[0];
squaredSum += value[1];
}
sum = sum / f32(H);
squaredSum = squaredSum / f32(H);
let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon);
let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;
output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
}`;

return context.compute(
{
name: 'InstanceNormComputeChannelScaleShift',
inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default],
cacheHint: JSON.stringify({components, n, h, c, epsilon}),
outputs: [
{dims: [n, c, 2], dataType: DataType.float, gpuDataType: GpuDataType.default},
],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(unitsOfWork / 64 /* workgroup size */)})
},
{inputs: [meanValues, scale, bias], outputs: [-1]})[0];
};

const createInstanceNormNHWCProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
(context: ComputeContext, metadata: ProgramMetadata, inputs: readonly TensorView[],
attributes: InstanceNormAttributes) => {
const xShape = inputs[0].dims;
const outputShape = xShape;
const outputSize = ShapeUtil.size(outputShape);
const N = xShape[0];
const C = xShape[xShape.length - 1];
const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;

const components = getMaxComponents(C);
const outputSize = ShapeUtil.size(outputShape) / components;
const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);

const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const scaleType = components === 1 ? `vec2f` : `mat2x${components}f`;
const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;
// first compute mean
const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);

const normCount = C * N;
const getShaderSource = (shaderHelper: ShaderHelper) => `
const N: u32 = ${N};
const H: u32 = ${H};
const C: u32 = ${C};
const normSizeTyped: ${dataType} = ${H};
const imageSize: u32 = ${H * C};
const epsilon: f32 = ${attributes.epsilon};
const C: u32 = ${C / components};
@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
@group(0) @binding(1) var<storage, read> scale : array<${dataType}>;
@group(0) @binding(2) var<storage, read> bias : array<${dataType}>;
@group(0) @binding(3) var<storage, read_write> output : array<${dataType}>;
@group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
@group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
@group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;
${shaderHelper.mainStart()}
let currentImageNumber = global_idx / C;
let currentImageNumber = global_idx / (C * H);
let currentChannelNumber = global_idx % C;
// offset is channel num * N
let offset = currentImageNumber * imageSize;
if (offset >= ${outputSize}) { return; }
var mean: ${dataType} = 0;
for (var i: u32 = 0u; i < H; i++) {
mean = mean + x[offset + i * C + currentChannelNumber];
}
mean = mean / normSizeTyped;
var squaredNorm: ${dataType} = 0;
for (var i: u32 = 0u; i < H; i++) {
let deviation: f32 = x[offset + i * C + currentChannelNumber] - mean;
squaredNorm = squaredNorm + deviation * deviation;
}
let invStdDev = 1 / sqrt(squaredNorm / normSizeTyped + epsilon);
let channelScale = invStdDev * scale[currentChannelNumber];
let channelShift = bias[currentChannelNumber] - mean * channelScale;
for (var i: u32 = 0u; i < H; i++) {
let currentOffset = offset + i * C + currentChannelNumber;
output[currentOffset] = x[currentOffset] * channelScale + channelShift;
}
let scaleOffset = currentImageNumber * C + currentChannelNumber;
let scale = scaleInput[scaleOffset];
output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
}`;
return {
...metadata,
outputs: [
{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(normCount / 64 /* workgroup size */)})
};
context.compute(
{
...metadata,
inputTypes: [GpuDataType.default, GpuDataType.default],
outputs: [
{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
},
{inputs: [inputs[0], channelScaleShift]});
};

export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes =>
Expand All @@ -177,7 +272,7 @@ export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAt
};

if (attributes.format === 'NHWC') {
context.compute(createInstanceNormNHWCProgramInfo(metadata, context.inputs, attributes));
createInstanceNormNHWCProgramInfo(context, metadata, context.inputs, attributes);
} else {
context.compute(createInstanceNormProgramInfo(metadata, context.inputs, attributes));
}
Expand Down
58 changes: 28 additions & 30 deletions js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';

import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';

export interface LayerNormAttributes extends AttributeWithCacheKey {
axis: number;
Expand All @@ -18,10 +17,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length < 2) {
throw new Error('layerNorm requires at least 2 inputs.');
}

if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) {
throw new Error('inputs should be float type');
}
};

const createLayerNormProgramInfo =
Expand All @@ -32,7 +27,6 @@ const createLayerNormProgramInfo =
const bias = inputs[2];

const outputShape = xShape;
const outputSize = ShapeUtil.size(outputShape);
const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length);
const normCount = ShapeUtil.sizeToDimension(xShape, axis);
const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
Expand All @@ -55,40 +49,44 @@ const createLayerNormProgramInfo =
}

const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const components = getMaxComponents(normSize);
const variables = [
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
inputVariable('scale', scale.dataType, scale.dims, components),
];
if (bias) {
variables.push(inputVariable('bias', bias.dataType, bias.dims, components));
}
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));

const hasMeanDataOutput = outputCount > 1;
const hasInvStdOutput = outputCount > 2;
let bindingIndex = 0;

if (hasMeanDataOutput) {
variables.push(outputVariable('meanDataOutput', inputs[0].dataType, meanInvStdDevDim));
}
if (hasInvStdOutput) {
variables.push(outputVariable('invStdOutput', inputs[0].dataType, meanInvStdDevDim));
}

const getShaderSource = (shaderHelper: ShaderHelper) => `
const normSize: u32 = ${normSize};
const normSize: u32 = ${normSize / components};
const normSizeTyped: ${dataType} = ${normSize};
const epsilon: f32 = ${attributes.epsilon};
@group(0) @binding(${bindingIndex++}) var<storage, read> x : array<${dataType}>;
@group(0) @binding(${bindingIndex++}) var<storage, read> scale : array<${dataType}>;
${bias ? `@group(0) @binding(${bindingIndex++}) var<storage, read> bias : array<${dataType}>;` : ''}
@group(0) @binding(${bindingIndex++}) var<storage, read_write> output : array<${dataType}>;
${
hasMeanDataOutput ?
`@group(0) @binding(${bindingIndex++}) var<storage, read_write> meanDataOutput : array<${dataType}>` :
''};
${
hasInvStdOutput ?
`@group(0) @binding(${bindingIndex++}) var<storage, read_write> invStdOutput : array<${dataType}>` :
''};
const epsilon: ${dataType} = ${attributes.epsilon};
${shaderHelper.declareVariables(...variables)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)}
let offset = global_idx * normSize;
if (offset >= ${outputSize}) { return; }
var mean: ${dataType} = 0;
var meanSquare: ${dataType} = 0;
var meanVector = ${fillVector(dataType, components)};
var meanSquareVector = ${fillVector(dataType, components)};
for (var h: u32 = 0u; h < normSize; h++) {
mean = mean + x[h + offset];
meanSquare = meanSquare + x[h + offset] * x[h + offset];
meanVector += x[h + offset];
meanSquareVector += x[h + offset] * x[h + offset];
}
mean = mean / normSizeTyped;
meanSquare = sqrt(meanSquare / normSizeTyped - mean * mean + epsilon);
let mean = ${sumVector('meanVector', components)} / normSizeTyped;
let meanSquare = sqrt(${sumVector('meanSquareVector', components)} / normSizeTyped - mean * mean + epsilon);
for (var j: u32 = 0; j < normSize; j++) {
output[j + offset] = (x[j + offset] - mean) / meanSquare * scale[j] ${bias ? '+ bias[j]' : ''};
Expand Down
Loading

0 comments on commit f6e7335

Please sign in to comment.