Skip to content

Commit

Permalink
Add uniforms to Einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Nov 21, 2023
1 parent 247ce21 commit e18eabc
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';

import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';

export interface EinsumAttributes extends AttributeWithCacheKey {
readonly equation: string;
Expand Down Expand Up @@ -179,13 +180,15 @@ class EinsumEquation {

const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => {
const dataType = inputs[0].dataType;
const inputVars = new Array<IndicesHelper>(inputs.length);
for (let i = 0; i < inputs.length; ++i) {
inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims);
}
const enableInputShapesUniforms = inputs.map((input, _) => enableShapesUniforms(input.dims.length));
const inputShapeOrRank =
inputs.map((input, index) => enableInputShapesUniforms[index] ? input.dims.length : input.dims);
const inputVars = inputs.map((_, index) => inputVariable(`input${index}`, dataType, inputShapeOrRank[index]));
const outputShape = einsumEquation.outputDims;
const outputSize = ShapeUtil.size(outputShape);
const output = outputVariable('output', dataType, outputShape);
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
const output = outputVariable('output', dataType, outputShapeOrRank);
const idxCopy: string[] = [];
const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys());
const initProd = 'var prod = 1.0;';
Expand Down Expand Up @@ -249,21 +252,34 @@ const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation:
...reduceOpsLoopFooters,
];
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(...inputVars, output)}
${shaderHelper.registerUniform('outputSize', 'u32').declareVariables(...inputVars, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
var outputIndices = ${output.offsetToIndices('global_idx')};
${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
${reduceOps.join('\n')};
${output.setByOffset('global_idx', 'sum')};
}`;

const programUniforms: ProgramUniform[] =
inputs.filter((_, index) => enableInputShapesUniforms[index])
.map((input, _) => [...createTensorShapeVariables(input.dims)])
.reduce(
(acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), [{type: 'uint32', data: outputSize}]);
if (enableOutputShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(outputShape));
}
return {
name: 'Einsum',
shaderCache: {hint: einsumEquation.equation},
shaderCache: {
hint: einsumEquation.equation,
inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims')
},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource,
};
Expand Down

0 comments on commit e18eabc

Please sign in to comment.