[js/webgpu] Add BatchNormalization Op (#18468)
### Description

This PR adds `BatchNormalization` with `float` support. Some todos:

1. Support inputs that do not all share the same data type. For example, x/y may be float16 while scale/bias are float32 or double.
2. Support training mode.

Many models use `BatchNormalization` ops. However, because this op was missing in JSEP, all of them ran on the CPU, which results in very poor performance. With this PR's support, the densenet-9 model goes from 250.69 ms to 20.29 ms.
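For context, a minimal usage sketch (not part of the diff) of how a model containing `BatchNormalization` nodes would exercise this kernel via the WebGPU execution provider in onnxruntime-web. The model path and the input name `data_0` are assumptions based on the densenet-9 model mentioned above, and depending on the onnxruntime-web version the WebGPU-enabled bundle may need to be imported from `onnxruntime-web/webgpu`:

```ts
// Sketch: run an ONNX model on the WebGPU EP so BatchNormalization executes on GPU.
// Assumptions: './densenet-9.onnx' is served locally and its input is named 'data_0'.
import * as ort from 'onnxruntime-web';

async function main(): Promise<void> {
  const session = await ort.InferenceSession.create('./densenet-9.onnx', {
    executionProviders: ['webgpu'],  // ops without a WebGPU kernel fall back to CPU
  });
  // densenet-9 expects a 1x3x224x224 float32 image tensor.
  const input = new ort.Tensor('float32', new Float32Array(1 * 3 * 224 * 224), [1, 3, 224, 224]);
  const outputs = await session.run({data_0: input});
  console.log(outputs);
}

main();
```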
Showing 9 changed files with 688 additions and 0 deletions.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env} from 'onnxruntime-common';

import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';

export interface BatchNormAttributes extends AttributeWithCacheKey {
  readonly epsilon: number;
  readonly momentum: number;
  readonly spatial: boolean;
  readonly trainingMode: boolean;
  readonly format: 'NHWC'|'NCHW';
  readonly outputCount: number;
}

const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttributes): void => {
  if (!inputs || inputs.length !== 5) {
    throw new Error('BatchNormalization requires 5 inputs');
  }

  const checkShapeEqual = (actual: readonly number[], expected: readonly number[], message: string) => {
    const r = expected.length;
    if (r !== actual.length) {
      throw new Error(`${message}: num dimensions != ${r}`);
    }
    expected.forEach((v, i) => {
      if (v !== actual[i]) {
        throw new Error(`${message}: dim[${i}] do not match`);
      }
    });
  };

  if (inputs[0].dims.length > 1) {
    // scale/B/mean/var must match the channel dimensions of x: with spatial = true that is just C
    // (the last dim for NHWC, dim 1 for NCHW); with spatial = false it is C x D1 x ... x Dn.
    const shape = attributes.format === 'NHWC' ?
        (attributes.spatial ? inputs[0].dims.slice(-1) :
                              inputs[0].dims.slice(-1).concat(inputs[0].dims.slice(1, inputs[0].dims.length - 1))) :
        inputs[0].dims.slice(1, attributes.spatial ? 2 : undefined);
    checkShapeEqual(inputs[1].dims, shape, 'Invalid input scale');
    checkShapeEqual(inputs[2].dims, shape, 'Invalid input B');
    checkShapeEqual(inputs[3].dims, shape, 'Invalid input mean');
    checkShapeEqual(inputs[4].dims, shape, 'Invalid input var');
  } else {
    checkShapeEqual(inputs[1].dims, [1], 'Invalid input scale');
    checkShapeEqual(inputs[2].dims, [1], 'Invalid input B');
    checkShapeEqual(inputs[3].dims, [1], 'Invalid input mean');
    checkShapeEqual(inputs[4].dims, [1], 'Invalid input var');
  }
};

const createBatchNormInferenceProgramInfo =
    (inputs: readonly TensorView[], attributes: BatchNormAttributes): ProgramInfo => {
      const {epsilon, spatial, format} = attributes;
      const yShape = inputs[0].dims;
      const components = spatial ? getMaxComponents(yShape[yShape.length - 1]) : 1;
      const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1;
      const outputSize = ShapeUtil.size(yShape) / components;
      // Only support uniforms for opset version >= 9 (spatial = true).
      const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial;
      const shapeOrRank = useShapesUniforms ? yShape.length : yShape;
      const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components);
      const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents);
      const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents);
      const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents);
      const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents);
      const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components);
      // TODO: support inputs with different data types. Currently all inputs must have the same
      // data type; otherwise the shader compilation will fail.
      const calcCOffset = (): string => {
        let cOffset = '';
        if (spatial) {
          cOffset = `let cOffset = ${
              yShape.length === 1 ? '0u' :
                  format === 'NHWC' ? `outputIndices[${yShape.length - 1}] / ${components}` :
                                      'outputIndices[1]'};`;
        } else {
          if (format === 'NCHW') {
            cOffset = `
            ${y.indicesSet('outputIndices', '0', '0')}
            let cOffset = ${y.indicesToOffset('outputIndices')};`;
          } else {
            // update C channel.
            cOffset = `var cIndices = ${scale.type.indices}(0);
                       cIndices[0] = outputIndices[${yShape.length - 1}];`;
            // update D1 x ... x Dn channels.
            for (let i = 1; i < scale.rank; i++) {
              cOffset += `cIndices[${i}] = outputIndices[${i}];`;
            }
            cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`;
          }
        }
        return cOffset;
      };
      const getInferenceModeShaderSource = (helper: ShaderHelper) => `
      const epsilon = ${epsilon};
      ${helper.registerUniform('outputSize', 'u32').declareVariables(x, scale, bias, inputMean, inputVar, y)}
      ${helper.mainStart()}
      ${helper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
        var outputIndices = ${y.offsetToIndices(`global_idx * ${components}`)};
        ${calcCOffset()}
        let scale = ${scale.getByOffset('cOffset')};
        let bias = ${bias.getByOffset('cOffset')};
        let inputMean = ${inputMean.getByOffset('cOffset')};
        let inputVar = ${inputVar.getByOffset('cOffset')};
        let x = ${x.getByOffset('global_idx')};
        let value = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias;
        ${y.setByOffset('global_idx', 'value')}
      }`;
      return {
        name: 'BatchNormalization',
        shaderCache: {
          hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`,
          inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined,
        },
        getShaderSource: getInferenceModeShaderSource,
        getRunData: () => ({
          outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}],
          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
          programUniforms: useShapesUniforms ?
              [
                {type: 'uint32', data: outputSize},
                ...createTensorShapeVariables(yShape),
              ] :
              [
                {type: 'uint32', data: outputSize},
              ],
        }),
      };
    };

export const parseBatchNormAttributes = (attributes: Record<string, unknown>): BatchNormAttributes =>
    createAttributeWithCacheKey(attributes as Omit<BatchNormAttributes, keyof AttributeWithCacheKey>);

export const batchNorm = (context: ComputeContext, attributes: Record<string, unknown>): void => {
  const {inputs, outputCount} = context;
  const updatedAttributes = parseBatchNormAttributes({...attributes, outputCount});
  if (env.webgpu.validateInputContent) {
    validateInputs(inputs, updatedAttributes);
  }
  if (attributes.trainingMode) {
    throw new Error('BatchNormalization trainingMode is not supported yet.');
  } else {
    context.compute(createBatchNormInferenceProgramInfo(inputs, updatedAttributes));
  }
};
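As a cross-check of the WGSL above, here is a scalar TypeScript reference (not part of the commit) of the inference-mode computation, `y = (x - mean) / sqrt(var + epsilon) * scale + bias`, applied per channel for an NCHW layout with spatial = true:

```ts
// Scalar reference for inference-mode BatchNormalization on an NCHW tensor (spatial = true).
// Mirrors the shader's `(x - inputMean) / sqrt(inputVar + epsilon) * scale + bias`,
// without vectorization or uniforms.
function batchNormRef(
    x: Float32Array, dims: readonly number[], scale: Float32Array, bias: Float32Array,
    mean: Float32Array, variance: Float32Array, epsilon: number): Float32Array {
  const c = dims[1];
  const spatialSize = dims.slice(2).reduce((a, b) => a * b, 1);
  const y = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) {
    const ch = Math.floor(i / spatialSize) % c;  // channel index recovered from the flat offset
    y[i] = (x[i] - mean[ch]) / Math.sqrt(variance[ch] + epsilon) * scale[ch] + bias[ch];
  }
  return y;
}
```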