Skip to content

Commit

Permalink
[js/webgpu] Allow binary ops with scalar to use the vectorize path (microsoft#17589)
Browse files Browse the repository at this point in the history

### Description
1. For binary ops, the number of components is always 4, so `dispatchGroup`
should be
`{x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}`
instead of
`{x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}`.

2. If either `a` or `b` has only one element, we can still use the vectorized
path, since that single value is broadcast to every element of the output.
  • Loading branch information
qjia7 authored Sep 22, 2023
1 parent 52b5f0c commit 5d67df2
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions web/lib/wasm/jsep/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,24 @@ const createBinaryOpProgramShader =
let assignment: string;
if (vectorize) {
if (doBroadcast) {
assignment = `
const isAOneElement = ShapeUtil.size(dimsA) === 1;
const isBOneElement = ShapeUtil.size(dimsB) === 1;
if (isAOneElement || isBOneElement) {
assignment = output.setByOffset(
'global_idx',
expressionVector(
isAOneElement ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'),
isBOneElement ? `${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx')));
} else {
assignment = `
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
let offsetA = calcOffsetA(outputIndices);
let offsetB = calcOffsetB(outputIndices);
${
output.setByOffset(
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
output.setByOffset(
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
`;
}
} else {
assignment = output.setByOffset(
'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx')));
Expand Down Expand Up @@ -141,6 +151,8 @@ const createBinaryOpProgramInfo =
}
outputShape = calculatedShape;
outputSize = ShapeUtil.size(outputShape);
const isAOneElement = ShapeUtil.size(a.dims) === 1;
const isBOneElement = ShapeUtil.size(b.dims) === 1;

// check whether vectorize can be enabled
let sharedDimension = 1;
Expand All @@ -153,7 +165,7 @@ const createBinaryOpProgramInfo =
break;
}
}
if (sharedDimension % 4 === 0) {
if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) {
vectorize = true;
}
} else {
Expand All @@ -167,8 +179,7 @@ const createBinaryOpProgramInfo =
shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType,
outputDataType, additionalImplementation),
outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}],
dispatchGroup: () =>
({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)})
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)})
};
};

Expand Down

0 comments on commit 5d67df2

Please sign in to comment.