Skip to content

Commit

Permalink
[js/webgpu] Add BatchNormalization Op (#18468)
Browse files Browse the repository at this point in the history
### Description
This PR adds `BatchNormalization` with `float` support.

Some Todos:
1. all inputs don't have same data type. For example, x/y is float16,
but bias/scale is float32 or double.
2. training mode support.

We see many models are using `BatchNormalization` ops. However, due to
the missing in jsep, all of them run on cpu, which result very poor
performance. With this PR's support, densenet-9 model becomes 20.29 ms
from 250.69 ms.
  • Loading branch information
qjia7 authored Nov 22, 2023
1 parent fa10694 commit 64dacc2
Show file tree
Hide file tree
Showing 9 changed files with 688 additions and 0 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Do not modify directly.*
| Atanh | ai.onnx(9+) | |
| Attention | com.microsoft(1+) | need implementing mask and past/present |
| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation |
| BatchNormalization | ai.onnx(7-8,9-13,14,15+); com.ms.internal.nhwc(7-8,9-13,14,15+) | |
| BiasAdd | com.microsoft(1+) | |
| BiasSplitGelu | com.microsoft(1+) | |
| Cast | ai.onnx(6-8,9-12,13-18,19+) | |
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax';
import {attention, parseAttentionAttributes} from './ops/attention';
import {batchNorm} from './ops/batch-norm';
import {biasAdd} from './ops/bias-add';
import {biasSplitGelu} from './ops/bias-split-gelu';
import * as binaryOps from './ops/binary-op';
Expand Down Expand Up @@ -51,6 +52,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Attention', [attention, parseAttentionAttributes]],
// TODO: support new attributes for AveragePool-10
['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]],
['BatchNormalization', [batchNorm]],
['BiasAdd', [biasAdd]],
['BiasSplitGelu', [biasSplitGelu]],
['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
Expand Down
150 changes: 150 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env} from 'onnxruntime-common';

import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';

export interface BatchNormAttributes extends AttributeWithCacheKey {
readonly epsilon: number;
readonly momentum: number;
readonly spatial: boolean;
readonly trainingMode: boolean;
readonly format: 'NHWC'|'NCHW';
readonly outputCount: number;
}

const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttributes): void => {
if (!inputs || inputs.length !== 5) {
throw new Error('BatchNormalization requires 5 inputs');
}

const checkShapeEqual = (actual: readonly number[], expected: readonly number[], message: string) => {
const r = expected.length;
if (r !== actual.length) {
throw new Error(`${message}: num dimensions != ${r}`);
}
expected.forEach((v, i) => {
if (v !== actual[i]) {
throw new Error(`${message}: dim[${i}] do not match`);
}
});
};

if (inputs[0].dims.length > 1) {
const shape = attributes.format === 'NHWC' ?
(attributes.spatial ? inputs[0].dims.slice(-1) :
inputs[0].dims.slice(-1).concat(inputs[0].dims.slice(1, inputs[0].dims.length - 1))) :
inputs[0].dims.slice(1, attributes.spatial ? 2 : undefined);
checkShapeEqual(inputs[1].dims, shape, 'Invalid input scale');
checkShapeEqual(inputs[2].dims, shape, 'Invalid input B');
checkShapeEqual(inputs[3].dims, shape, 'Invalid input mean');
checkShapeEqual(inputs[4].dims, shape, 'Invalid input var');
} else {
checkShapeEqual(inputs[1].dims, [1], 'Invalid input scale');
checkShapeEqual(inputs[2].dims, [1], 'Invalid input B');
checkShapeEqual(inputs[3].dims, [1], 'Invalid input mean');
checkShapeEqual(inputs[4].dims, [1], 'Invalid input var');
}
};

const createBatchNormInferenceProgramInfo =
(inputs: readonly TensorView[], attributes: BatchNormAttributes): ProgramInfo => {
const {epsilon, spatial, format} = attributes;
const yShape = inputs[0].dims;
const components = spatial ? getMaxComponents(yShape[yShape.length - 1]) : 1;
const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1;
const outputSize = ShapeUtil.size(yShape) / components;
// Only support uniforms for opset version >= 9 (spatial = true).
const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial;
const shapeOrRank = useShapesUniforms ? yShape.length : yShape;
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components);
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents);
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents);
const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents);
const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents);
const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components);
// TODO: support inputs with different data type. Current we need to make sure all inputs have the same data type.
// Otherwise, the shader compilation will fail.
const calcCOffset = (): string => {
let cOffset = '';
if (spatial) {
cOffset = `let cOffset = ${
yShape.length === 1 ? '0u' :
format === 'NHWC' ? `outputIndices[${yShape.length - 1}] / ${components}` :
'outputIndices[1]'};`;
} else {
if (format === 'NCHW') {
cOffset = `
${y.indicesSet('outputIndices', '0', '0')}
let cOffset = ${y.indicesToOffset('outputIndices')};`;
} else {
// update C channel.
cOffset = `var cIndices = ${scale.type.indices}(0);
cIndices[0] = outputIndices[${yShape.length - 1}];`;
// update D1 x ... x Dn channels.
for (let i = 1; i < scale.rank; i++) {
cOffset += `cIndices[${i}] = outputIndices[${i}];`;
}
cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`;
}
}
return cOffset;
};
const getInferenceModeShaderSource = (helper: ShaderHelper) => `
const epsilon = ${epsilon};
${helper.registerUniform('outputSize', 'u32').declareVariables(x, scale, bias, inputMean, inputVar, y)}
${helper.mainStart()}
${helper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
var outputIndices = ${y.offsetToIndices(`global_idx * ${components}`)};
${calcCOffset()}
let scale = ${scale.getByOffset('cOffset')};
let bias = ${bias.getByOffset('cOffset')};
let inputMean = ${inputMean.getByOffset('cOffset')};
let inputVar = ${inputVar.getByOffset('cOffset')};
let x = ${x.getByOffset('global_idx')};
let value = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias;
${y.setByOffset('global_idx', 'value')}
}`;
return {
name: 'BatchNormalization',
shaderCache: {
hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`,
inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined,
},
getShaderSource: getInferenceModeShaderSource,
getRunData: () => ({
outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: useShapesUniforms ?
[
{type: 'uint32', data: outputSize},
...createTensorShapeVariables(yShape),
] :
[
{type: 'uint32', data: outputSize},
],
}),
};
};

export const parseBatchNormAttributes = (attributes: Record<string, unknown>): BatchNormAttributes =>
createAttributeWithCacheKey(attributes as Omit<BatchNormAttributes, keyof AttributeWithCacheKey>);

export const batchNorm = (context: ComputeContext, attributes: Record<string, unknown>): void => {
const {inputs, outputCount} = context;
const updatedAttributes = parseBatchNormAttributes({...attributes, outputCount});
if (env.webgpu.validateInputContent) {
validateInputs(inputs, updatedAttributes);
}
if (attributes.trainingMode) {
throw new Error('BatchNormalization trainingMode is not supported yet.');
} else {
context.compute(createBatchNormInferenceProgramInfo(inputs, updatedAttributes));
}
};
Loading

0 comments on commit 64dacc2

Please sign in to comment.