diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
index eab571e87f5f5..0992552ebaf25 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
 import {BroadcastUtil, ShapeUtil} from '../../util';
 import {ComputeContext, ProgramInfo} from '../types';
 
-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
 
 type BuiltinFunctionName = string;
 type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;
@@ -18,10 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{
 const createBinaryOpProgramShader =
     (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[],
      vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number,
-     typeOutput: number, additionalImplementation?: string) => {
-      const outputSize = ShapeUtil.size(dimsOutput);
-      const vecSize = Math.ceil(outputSize / 4);
-
+     typeOutput: number, useShapesUniforms: boolean, additionalImplementation?: string) => {
       let expressionScalar: BinaryCustomExpression;
       let expressionVector: BinaryCustomExpression;
       if (typeof funcCall === 'string') {
@@ -33,31 +30,12 @@ const createBinaryOpProgramShader =
         expressionVector = funcCall.vector;
       }
 
-      let broadcastImpl = '';
-      const output = outputVariable('outputData', typeOutput, dimsOutput, 4);
-      const a = inputVariable('aData', typeA, dimsA, 4);
-      const b = inputVariable('bData', typeB, dimsB, 4);
-      if (doBroadcast) {
-        const calcOffsetImpl = (dims: readonly number[]) => {
-          const strides = ShapeUtil.computeStrides(dims);
-          const offsets: string[] = [];
-          for (let i = dims.length - 1; i >= 0; i--) {
-            const idx = output.indicesGet('outputIndices', i + dimsOutput.length - dims.length);
-            offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`);
-          }
-          return offsets.length > 0 ? offsets.join('+') : '0u';
-        };
-
-        broadcastImpl = `
-          fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 {
-            return ${calcOffsetImpl(dimsA)};
-          }
-
-          fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 {
-            return ${calcOffsetImpl(dimsB)};
-          }
-        `;
-      }
+      const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA;
+      const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB;
+      const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput;
+      const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4);
+      const a = inputVariable('aData', typeA, inputAShapeOrRank, 4);
+      const b = inputVariable('bData', typeB, inputBShapeOrRank, 4);
 
       let assignment: string;
       if (vectorize) {
@@ -73,8 +51,8 @@ const createBinaryOpProgramShader =
         } else {
           assignment = `
             let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
-            let offsetA = calcOffsetA(outputIndices);
-            let offsetB = calcOffsetB(outputIndices);
+            let offsetA = ${a.broadcastedIndicesToOffset('outputIndices', output)};
+            let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)};
             ${
               output.setByOffset(
                   'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
@@ -94,8 +72,8 @@ const createBinaryOpProgramShader =
           const expressionB = `bData[indexB${x}][componentB${x}]`;
           return `
             let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
-            let offsetA${x} = calcOffsetA(outputIndices${x});
-            let offsetB${x} = calcOffsetB(outputIndices${x});
+            let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
+            let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
             let indexA${x} = offsetA${x} / 4u;
             let indexB${x} = offsetB${x} / 4u;
             let componentA${x} = offsetA${x} % 4u;
@@ -122,13 +100,12 @@ const createBinaryOpProgramShader =
       }
 
       return `
-        ${shaderHelper.declareVariables(a, b, output)}
+        ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(a, b, output)}
 
         ${additionalImplementation ?? ''}
-        ${broadcastImpl}
 
         ${shaderHelper.mainStart()}
-        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
+        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
         ${assignment}
       }`;
     };
@@ -144,6 +121,7 @@ const createBinaryOpProgramInfo =
 
       // TODO: deal with zero-sized tensors (eg. dims=[1,0])
 
+      const cacheKeyAux = [isBroadcast];
       if (isBroadcast) {
         const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false);
         if (!calculatedShape) {
@@ -153,7 +131,8 @@ const createBinaryOpProgramInfo =
         outputSize = ShapeUtil.size(outputShape);
         const isAOneElement = ShapeUtil.size(a.dims) === 1;
         const isBOneElement = ShapeUtil.size(b.dims) === 1;
-
+        cacheKeyAux.push(isAOneElement);
+        cacheKeyAux.push(isBOneElement);
         // check whether vectorize can be enabled
         let sharedDimension = 1;
         for (let i = 1; i < outputShape.length; i++) {
@@ -172,16 +151,34 @@ const createBinaryOpProgramInfo =
         // element-wise
         vectorize = true;
       }
-
+      cacheKeyAux.push(vectorize);
+      const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) &&
+          enableShapesUniforms(outputShape.length);
       return {
         name,
-        shaderCache: {hint: cacheKey},
+        shaderCache: {
+          hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'),
+          // If the input is scalar then use type instead of dims because useShapesUniforms is false.
+          inputDependencies: useShapesUniforms ?
+              ['rank', 'rank'] :
+              [a.dims.length > 0 ? 'dims' : 'type', b.dims.length > 0 ? 'dims' : 'type'],
+        },
         getShaderSource: (shaderHelper) => createBinaryOpProgramShader(
             shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType,
-            outputDataType, additionalImplementation),
+            outputDataType, useShapesUniforms, additionalImplementation),
         getRunData: () => ({
           outputs: [{dims: outputShape, dataType: outputDataType}],
-          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}
+          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
+          programUniforms: useShapesUniforms ?
+              [
+                {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
+                ...createTensorShapeVariables(a.dims),
+                ...createTensorShapeVariables(b.dims),
+                ...createTensorShapeVariables(outputShape),
+              ] :
+              [
+                {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
+              ],
        }),
      };
    };
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index 1d3fc78fe368a..7352517733c3e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -805,4 +805,4 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
 };
 
 // TODO: remove this limitation once >4D dims are supported by uniform.
-export const enableShapesUniforms = (rank: number): boolean => rank <= 4;
+export const enableShapesUniforms = (rank: number): boolean => rank <= 4 && rank > 0;
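
For readers outside the ONNX Runtime codebase, the following standalone TypeScript sketch illustrates the caching idea behind the inputDependencies change above. The names MAX_UNIFORM_RANK, canUseShapeUniforms, ShaderCacheKey and buildCacheKey are hypothetical and are not part of the real API: when shapes are delivered as uniforms, the shader cache key only needs each tensor's rank, so tensors of the same rank but different sizes can reuse a single compiled pipeline, while a scalar (rank 0) input falls back to a type-only dependency, matching the rank > 0 guard added to enableShapesUniforms.

// Minimal, self-contained sketch of the rank-vs-dims cache-key idea (hypothetical names).
const MAX_UNIFORM_RANK = 4;  // mirrors enableShapesUniforms: rank <= 4 && rank > 0

const canUseShapeUniforms = (rank: number): boolean => rank > 0 && rank <= MAX_UNIFORM_RANK;

interface ShaderCacheKey {
  hint: string;                                    // operator-specific part of the key
  inputDependencies: Array<'rank'|'dims'|'type'>;  // what the compiled shader depends on per input
}

const buildCacheKey =
    (name: string, aDims: readonly number[], bDims: readonly number[], outDims: readonly number[]):
        ShaderCacheKey => {
      const useShapeUniforms =
          canUseShapeUniforms(aDims.length) && canUseShapeUniforms(bDims.length) &&
          canUseShapeUniforms(outDims.length);
      return {
        hint: name,
        // With shape uniforms the key varies only with rank; otherwise it must vary with the full
        // dims, and a scalar input contributes only its type.
        inputDependencies: useShapeUniforms ?
            ['rank', 'rank'] :
            [aDims.length > 0 ? 'dims' : 'type', bDims.length > 0 ? 'dims' : 'type'],
      };
    };

// Two different 2-D shape pairs produce the same kind of key, so one compiled pipeline can serve both.
console.log(buildCacheKey('Add', [2, 3], [2, 3], [2, 3]));    // inputDependencies: ['rank', 'rank']
console.log(buildCacheKey('Add', [8, 1], [8, 16], [8, 16]));  // inputDependencies: ['rank', 'rank']
console.log(buildCacheKey('Add', [], [4], [4]));              // scalar input falls back: ['type', 'dims']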
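
Likewise, the hypothetical toUniforms helper below sketches the per-dispatch uniform payload that getRunData now supplies instead of baking sizes into the WGSL source: a u32 vec_size consumed by the out-of-bounds guard, followed by the a, b and output shapes read by the generated indexing helpers (offsetToIndices, broadcastedIndicesToOffset). It illustrates the data layout only and is not the actual createTensorShapeVariables implementation.

// Hypothetical sketch of the uniform data uploaded at dispatch time under this scheme.
type ProgramUniform = {type: 'uint32'; data: number|number[]};

const toUniforms =
    (aDims: readonly number[], bDims: readonly number[], outDims: readonly number[]): ProgramUniform[] => {
      const outputSize = outDims.reduce((p, d) => p * d, 1);
      return [
        {type: 'uint32', data: Math.ceil(outputSize / 4)},  // vec_size: guards the vectorized main loop
        {type: 'uint32', data: [...aDims]},                 // shape of input a
        {type: 'uint32', data: [...bDims]},                 // shape of input b
        {type: 'uint32', data: [...outDims]},               // shape of the broadcast output
      ];
    };

// Changing only the sizes changes this payload, not the shader source, so no recompilation is needed.
console.log(toUniforms([2, 1, 4], [2, 3, 4], [2, 3, 4]));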