Skip to content

Commit

Permalink
Fixed the reduced symbols on the left-hand side to use uniforms value…
Browse files Browse the repository at this point in the history
…s for foo-loop limit in the shader code.
  • Loading branch information
satyajandhyala committed Nov 26, 2023
1 parent a514758 commit 663ed49
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
25 changes: 18 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';

Check notice

Code scanning / CodeQL

Unused variable, import, function or class Note

Unused import UniformsArrayType.


export interface EinsumAttributes extends AttributeWithCacheKey {
Expand Down Expand Up @@ -197,6 +197,7 @@ const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation:
const reduceOpsLoopHeaders: string[] = [];
const reduceOpsLoopFooters: string[] = [];
const reduceOpCompute: string[] = [];
const uniformsSymbols: string[] = []; // Equations symbols that require dim limit added to Uniforms.
const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size;
einsumEquation.symbolToInfo.forEach((info, symbol) => {
if (einsumEquation.rhs.symbolToIndices.has(symbol)) {
Expand Down Expand Up @@ -233,8 +234,8 @@ const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation:
reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`);
}
});
reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${
einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`);
reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${symbol}_max; ${symbol}++) {`);
uniformsSymbols.push(symbol)
reduceOpsLoopFooters.push('}');
}
});
Expand All @@ -254,21 +255,31 @@ const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation:
...reduceOpsLoopFooters,
];
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniform('outputSize', 'u32').declareVariables(...inputVars, output)}
${
shaderHelper.registerUniforms(uniformsSymbols.map((symbol) => ({name: `${symbol}_max`, type: 'u32'})))
.registerUniform('outputSize', 'u32')
.declareVariables(...inputVars, output)}
${shaderHelper.mainStart()}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
var outputIndices = ${output.offsetToIndices('global_idx')};
${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
${reduceOps.join('\n')};
${output.setByOffset('global_idx', 'sum')};
}`;

// The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The filter
// is added to make sure that dimValue is never 0.
const programUniformsInit: ProgramUniform[] =
uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
.map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
programUniformsInit.push({type: 'uint32', data: outputSize});

const programUniforms: ProgramUniform[] =
inputs.filter((_, index) => enableInputShapesUniforms[index])
.map((input, _) => createTensorShapeVariables(input.dims))
.reduce(
(acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), [{type: 'uint32', data: outputSize}]);
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);

if (enableOutputShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(outputShape));
}
Expand Down
24 changes: 23 additions & 1 deletion js/web/test/data/ops/einsum.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
],
"cases": [
{
"name": "Multiply",
"name": "Multiply (2,3) X (3,4) -> (2,4)",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
Expand All @@ -269,6 +269,28 @@
"type": "float32"
}
]
},
{
"name": "Multiply (2,6) X (6,4) -> (2,4)",
"inputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9 ,10, 11],
"dims": [2, 6],
"type": "float32"
},
{
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
"dims": [6, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [220, 235, 250, 265, 580, 631, 682, 733],
"dims": [2, 4],
"type": "float32"
}
]
}
]
},
Expand Down

0 comments on commit 663ed49

Please sign in to comment.