Skip to content

Commit

Permalink
[js/webgpu] Optimize broadcast binary. (#18185)
Browse files Browse the repository at this point in the history
### Description
Currently, the binary algorithms are divided into a vectorized one
(efficient) and a non-vectorized one (less efficient). The following
situations use the vectorized one:
1) A or B contains only one element (its size is 1).
2) The shared dimension length of A and B is divisible by 4.
3) A and B have the same shape.

This PR adds one more situation that uses the vectorized
algorithm:
4) A's or B's last dimension is divisible by 4.

With this change, the aggregate time of Add in sam-b-encoder drops
from 409.12 ms to 309.65 ms on Intel ADL.
  • Loading branch information
qjia7 authored Nov 21, 2023
1 parent cc54202 commit abdf8b7
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{

const createBinaryOpProgramShader =
(shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[],
vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number,
typeOutput: number, useShapesUniforms: boolean, additionalImplementation?: string) => {
vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall,
typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean,
additionalImplementation?: string) => {
let expressionScalar: BinaryCustomExpression;
let expressionVector: BinaryCustomExpression;
if (typeof funcCall === 'string') {
Expand All @@ -42,6 +43,8 @@ const createBinaryOpProgramShader =
if (doBroadcast) {
const isAOneElement = ShapeUtil.size(dimsA) === 1;
const isBOneElement = ShapeUtil.size(dimsB) === 1;
const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0;
const bLastDimDivisibleBy4 = dimsB.length > 0 && dimsB[dimsB.length - 1] % 4 === 0;
if (isAOneElement || isBOneElement) {
assignment = output.setByOffset(
'global_idx',
Expand All @@ -55,7 +58,14 @@ const createBinaryOpProgramShader =
let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)};
${
output.setByOffset(
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
'global_idx',
expressionVector(
sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4 ?
a.getByOffset('offsetA / 4u') :
`${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`,
sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4 ?
b.getByOffset('offsetB / 4u') :
`${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`))}
`;
}
} else {
Expand Down Expand Up @@ -118,6 +128,7 @@ const createBinaryOpProgramInfo =
let outputSize = ShapeUtil.size(a.dims);

let vectorize = false;
let sharedDimensionDivisibleBy4 = false;

// TODO: deal with zero-sized tensors (eg. dims=[1,0])
const cacheKeyAux = [isBroadcast];
Expand All @@ -130,8 +141,12 @@ const createBinaryOpProgramInfo =
outputSize = ShapeUtil.size(outputShape);
const isAOneElement = ShapeUtil.size(a.dims) === 1;
const isBOneElement = ShapeUtil.size(b.dims) === 1;
const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0;
const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0;
cacheKeyAux.push(isAOneElement);
cacheKeyAux.push(isBOneElement);
cacheKeyAux.push(aLastDimDivisibleBy4);
cacheKeyAux.push(bLastDimDivisibleBy4);
// check whether vectorize can be enabled
let sharedDimension = 1;
for (let i = 1; i < outputShape.length; i++) {
Expand All @@ -143,7 +158,10 @@ const createBinaryOpProgramInfo =
break;
}
}
if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) {
if (sharedDimension % 4 === 0) {
sharedDimensionDivisibleBy4 = true;
vectorize = true;
} else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) {
vectorize = true;
}
} else {
Expand All @@ -160,8 +178,8 @@ const createBinaryOpProgramInfo =
inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'],
},
getShaderSource: (shaderHelper) => createBinaryOpProgramShader(
shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType,
outputDataType, useShapesUniforms, additionalImplementation),
shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall,
a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation),
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
Expand Down

0 comments on commit abdf8b7

Please sign in to comment.