Skip to content

Commit

Permalink
[JS/Web] Add uniforms to Einsum (#18531)
Browse files Browse the repository at this point in the history
### Description
Add uinforms to Einsum



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Improve performance.
  • Loading branch information
satyajandhyala authored Nov 29, 2023
1 parent 483c490 commit 7335760
Show file tree
Hide file tree
Showing 2 changed files with 453 additions and 97 deletions.
220 changes: 126 additions & 94 deletions js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';

import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';

export interface EinsumAttributes extends AttributeWithCacheKey {
readonly equation: string;
Expand Down Expand Up @@ -101,7 +102,7 @@ class EinsumEquation {
this.outputDims.push(info.dimValue);
}
});
this.rhs = this.processTerm(rhs, true, this.outputDims);
this.rhs = this.processTerm(rhs, false, this.outputDims);
} // End of EinsumEqation constructor

// Add a symbol to the equation
Expand Down Expand Up @@ -157,12 +158,12 @@ class EinsumEquation {
}
// Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling
for (let j = 0; j < ellipsisDims.length; j++) {
const symbol = String.fromCharCode('0'.charCodeAt(0) + i);
const symbol = String.fromCharCode('0'.charCodeAt(0) + j);
einsumTerm.addSymbol(symbol, i + j);
this.addSymbol(symbol, dims[nextDim++], index);
}
} else {
einsumTerm.addSymbol(symbol, i);
einsumTerm.addSymbol(symbol, i + (this.hasEllipsis ? this.ellipsisDims.length - 1 : 0));
this.addSymbol(symbol, dims[nextDim++], index);
}
});
Expand All @@ -177,101 +178,132 @@ class EinsumEquation {
outputDims: number[]; // Output dimensions of the equation
} // End of class EinsumEquation

const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => {
const dataType = inputs[0].dataType;
const inputVars = new Array<IndicesHelper>(inputs.length);
for (let i = 0; i < inputs.length; ++i) {
inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims);
}
const outputShape = einsumEquation.outputDims;
const outputSize = ShapeUtil.size(outputShape);
const output = outputVariable('output', dataType, outputShape);
const idxCopy: string[] = [];
const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys());
const initProd = 'var prod = 1.0;';
const initSum = 'var sum = 0.0;';
const updateSum = 'sum += prod;';
const reduceOpsSetIndices: string[] = [];
const reduceOpsLoopHeaders: string[] = [];
const reduceOpsLoopFooters: string[] = [];
const reduceOpCompute: string[] = [];
const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === rhsSymbols.length;
einsumEquation.symbolToInfo.forEach((info, symbol) => {
if (rhsSymbols.includes(symbol)) {
const outputIndex = rhsSymbols.indexOf(symbol);
einsumEquation.lhs.forEach((term, i) => {
if (info.inputIndices.includes(i)) {
const indices = term.symbolToIndices.get(symbol);
if (indices === undefined) {
throw new Error('Invalid symbol error');
const appendMax = (name: string): string => name + '_max';

const createEinsumProgramInfo =
(enableInputShapesUniforms: readonly boolean[], inputShapes: Array<readonly number[]>, dataType: number,
einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => {
const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims);
const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank));
const outputSize = ShapeUtil.size(outputShape);
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
const output = outputVariable('output', dataType, outputShapeOrRank);
const uniformsSymbols =
[...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol));
const getShaderSource = (shaderHelper: ShaderHelper) => {
const idxCopy: string[] = [];
const initProd = 'var prod = 1.0;';
const initSum = 'var sum = 0.0;';
const updateSum = 'sum += prod;';
const reduceOpsSetIndices: string[] = [];
const reduceOpsLoopHeaders: string[] = [];
const reduceOpsLoopFooters: string[] = [];
const reduceOpCompute: string[] = [];
const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size;
einsumEquation.symbolToInfo.forEach((info, symbol) => {
if (einsumEquation.rhs.symbolToIndices.has(symbol)) {
const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0];
if (outputIndex !== undefined) {
einsumEquation.lhs.forEach((term, i) => {
if (info.inputIndices.includes(i)) {
const indices = term.symbolToIndices.get(symbol);
if (indices === undefined) {
throw new Error('Invalid symbol error');
}
indices.forEach((index) => {
idxCopy.push(`${
inputVars[i].indicesSet(
`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`);
});
}
});
}
} else {
einsumEquation.lhs.forEach((term, i) => {
if (info.inputIndices.includes(i)) {
const indices = term.symbolToIndices.get(symbol);
if (indices === undefined) {
throw new Error('Invalid symbol error');
}
indices.forEach((index) => {
reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`);
});
reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`);
}
});
reduceOpsLoopHeaders.push(
`for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`);
reduceOpsLoopFooters.push('}');
}
indices.forEach((index) => {
idxCopy.push(`${
inputVars[i].indicesSet(`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`);
});
}
});
} else {
einsumEquation.lhs.forEach((term, i) => {
const info = einsumEquation.symbolToInfo.get(symbol);
if (info === undefined) {
throw new Error('Invalid symbol error');
}
if (info.inputIndices.includes(i)) {
const indices = term.symbolToIndices.get(symbol);
if (indices === undefined) {
throw new Error('Invalid symbol error');
});
const reduceOps = isReduceOpsWithoutLoop ?
[
...idxCopy,
`let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`
] :
[
...idxCopy,
initSum,
...reduceOpsLoopHeaders,
...reduceOpsSetIndices,
initProd,
...reduceOpCompute,
updateSum,
...reduceOpsLoopFooters,
];
return `
${
shaderHelper
.registerUniforms(uniformsSymbols.map((symbol) => ({name: `${appendMax(symbol)}`, type: 'u32'})))
.registerUniform('outputSize', 'u32')
.declareVariables(...inputVars, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
var outputIndices = ${output.offsetToIndices('global_idx')};
${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
${reduceOps.join('\n')};
${output.setByOffset('global_idx', 'sum')};
}`;
};
return {
name: 'Einsum',
shaderCache: {
hint: einsumEquation.equation,
inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims')
},
getRunData: () => {
// The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The
// filter is added to make sure that dimValue is never 0.
const programUniformsInit: ProgramUniform[] =
uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
.map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
programUniformsInit.push({type: 'uint32', data: outputSize});
const programUniforms: ProgramUniform[] =
inputShapes.filter((_, index) => enableInputShapesUniforms[index])
.map((dims, _) => [...createTensorShapeVariables(dims)])
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
if (enableOutputShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(outputShape));
}
indices.forEach((index) => {
reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`);
return ({
outputs: [{dims: outputShape, dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
});
reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`);
}
});
reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${
einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`);
reduceOpsLoopFooters.push('}');
}
});
const reduceOps = isReduceOpsWithoutLoop ?
[
...idxCopy,
`let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`
] :
[
...idxCopy,
initSum,
...reduceOpsLoopHeaders,
...reduceOpsSetIndices,
initProd,
...reduceOpCompute,
updateSum,
...reduceOpsLoopFooters,
];
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(...inputVars, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
var outputIndices = ${output.offsetToIndices('global_idx')};
${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
${reduceOps.join('\n')};
${output.setByOffset('global_idx', 'sum')};
}`;
return {
name: 'Einsum',
shaderCache: {hint: einsumEquation.equation},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
}),
getShaderSource,
};
};
},
getShaderSource,
};
};

export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
context.compute(createEinsumProgramInfo(context.inputs, einsumEquation));
const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length));
const outputShape = einsumEquation.outputDims;
const inputShapes = context.inputs.map((input, _) => input.dims);
context.compute(createEinsumProgramInfo(
enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
};

export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
Expand Down
Loading

0 comments on commit 7335760

Please sign in to comment.