Skip to content

Commit

Permalink
[js/webgpu] Allow binary ops with scalar to use the vectorize path
Browse files: browse the repository at this point in the history
  • Loading branch information
qjia7 committed Sep 18, 2023
1 parent c969237 commit 2b62e1e
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
Original file line number | Diff line number | Diff line change
Expand Up @@ -62,14 +62,24 @@ const createBinaryOpProgramShader =
let assignment: string;
if (vectorize) {
if (doBroadcast) {
assignment = `
const isAScalar = dimsA.length === 0;
const isBScalar = dimsB.length === 0;
if (isAScalar || isBScalar) {
assignment = output.setByOffset(
'global_idx',
expressionVector(
isAScalar ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'),
isBScalar ? `${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx')));
} else {
assignment = `
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
let offsetA = calcOffsetA(outputIndices);
let offsetB = calcOffsetB(outputIndices);
${
output.setByOffset(
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
output.setByOffset(
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
`;
}
} else {
assignment = output.setByOffset(
'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx')));
Expand Down Expand Up @@ -141,6 +151,8 @@ const createBinaryOpProgramInfo =
}
outputShape = calculatedShape;
outputSize = ShapeUtil.size(outputShape);
const isAScalar = a.dims.length === 0;
const isBScalar = b.dims.length === 0;

// check whether vectorize can be enabled
let sharedDimension = 1;
Expand All @@ -153,7 +165,7 @@ const createBinaryOpProgramInfo =
break;
}
}
if (sharedDimension % 4 === 0) {
if (sharedDimension % 4 === 0 || isAScalar || isBScalar) {
vectorize = true;
}
} else {
Expand All @@ -167,8 +179,7 @@ const createBinaryOpProgramInfo =
shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType,
outputDataType, additionalImplementation),
outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}],
dispatchGroup: () =>
({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)})
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)})
};
};

Expand Down

0 comments on commit 2b62e1e

Please sign in to comment.