From 2b62e1e4d52d4e6d2d8b11f082f7a6b514d482af Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 18 Sep 2023 16:50:18 +0800 Subject: [PATCH] [js/webgpu] Allow binary ops with scalar to use the vectorize path --- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 23 +++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 13d3a91bb339e..20a421c4ee5d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -62,14 +62,24 @@ const createBinaryOpProgramShader = let assignment: string; if (vectorize) { if (doBroadcast) { - assignment = ` + const isAScalar = dimsA.length === 0; + const isBScalar = dimsB.length === 0; + if (isAScalar || isBScalar) { + assignment = output.setByOffset( + 'global_idx', + expressionVector( + isAScalar ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'), + isBScalar ? `${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx'))); + } else { + assignment = ` let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; let offsetA = calcOffsetA(outputIndices); let offsetB = calcOffsetB(outputIndices); ${ - output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} + output.setByOffset( + 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} `; + } } else { assignment = output.setByOffset( 'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx'))); @@ -141,6 +151,8 @@ const createBinaryOpProgramInfo = } outputShape = calculatedShape; outputSize = ShapeUtil.size(outputShape); + const isAScalar = a.dims.length === 0; + const isBScalar = b.dims.length === 0; // check whether vectorize can be enabled let sharedDimension = 1; @@ -153,7 +165,7 @@ const createBinaryOpProgramInfo = break; } } - if (sharedDimension % 4 === 0) { + if (sharedDimension % 4 === 0 || isAScalar || isBScalar) { vectorize = true; } } else { @@ -167,8 +179,7 @@ const createBinaryOpProgramInfo = shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, outputDataType, additionalImplementation), outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => - ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}) + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}) }; };