diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index a233d37a79e65..892f1b7f02141 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -4,9 +4,10 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface EinsumAttributes extends AttributeWithCacheKey { readonly equation: string; @@ -179,13 +180,15 @@ class EinsumEquation { const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => { const dataType = inputs[0].dataType; - const inputVars = new Array(inputs.length); - for (let i = 0; i < inputs.length; ++i) { - inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); - } + const enableInputShapesUniforms = inputs.map((input, _) => enableShapesUniforms(input.dims.length)); + const inputShapeOrRank = + inputs.map((input, index) => enableInputShapesUniforms[index] ? input.dims.length : input.dims); + const inputVars = inputs.map((_, index) => inputVariable(`input${index}`, dataType, inputShapeOrRank[index])); const outputShape = einsumEquation.outputDims; const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', dataType, outputShape); + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); const idxCopy: string[] = []; const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys()); const initProd = 'var prod = 1.0;'; @@ -249,21 +252,34 @@ const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: ...reduceOpsLoopFooters, ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(...inputVars, output)} + ${shaderHelper.registerUniform('outputSize', 'u32').declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} var outputIndices = ${output.offsetToIndices('global_idx')}; ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} ${reduceOps.join('\n')}; ${output.setByOffset('global_idx', 'sum')}; }`; + + const programUniforms: ProgramUniform[] = + inputs.filter((_, index) => enableInputShapesUniforms[index]) + .map((input, _) => [...createTensorShapeVariables(input.dims)]) + .reduce( + (acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), [{type: 'uint32', data: outputSize}]); + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } return { name: 'Einsum', - shaderCache: {hint: einsumEquation.equation}, + shaderCache: { + hint: einsumEquation.equation, + inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims') + }, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, };