From 7335760424b052ff041285571cf52b77f9ebb009 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Wed, 29 Nov 2023 15:30:33 -0800 Subject: [PATCH] [JS/Web] Add uniforms to Einsum (#18531) ### Description Add uniforms to Einsum ### Motivation and Context Improve performance. --- js/web/lib/wasm/jsep/webgpu/ops/einsum.ts | 220 +++++++++------ js/web/test/data/ops/einsum.jsonc | 330 +++++++++++++++++++++- 2 files changed, 453 insertions(+), 97 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index a233d37a79e65..4db7c04ad67be 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -4,9 +4,10 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface EinsumAttributes extends AttributeWithCacheKey { readonly equation: string; @@ -101,7 +102,7 @@ class EinsumEquation { this.outputDims.push(info.dimValue); } }); - this.rhs = this.processTerm(rhs, true, this.outputDims); + this.rhs = this.processTerm(rhs, false, this.outputDims); } // End of EinsumEqation constructor // Add a symbol to the equation @@ -157,12 +158,12 @@ class EinsumEquation { } // Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling for (let j = 0; j < ellipsisDims.length; j++) { - const symbol = String.fromCharCode('0'.charCodeAt(0) + i); + const symbol = String.fromCharCode('0'.charCodeAt(0) + j); einsumTerm.addSymbol(symbol, i + j); this.addSymbol(symbol, dims[nextDim++], index); } } 
else { - einsumTerm.addSymbol(symbol, i); + einsumTerm.addSymbol(symbol, i + (this.hasEllipsis ? this.ellipsisDims.length - 1 : 0)); this.addSymbol(symbol, dims[nextDim++], index); } }); @@ -177,101 +178,132 @@ class EinsumEquation { outputDims: number[]; // Output dimensions of the equation } // End of class EinsumEquation -const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => { - const dataType = inputs[0].dataType; - const inputVars = new Array(inputs.length); - for (let i = 0; i < inputs.length; ++i) { - inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); - } - const outputShape = einsumEquation.outputDims; - const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', dataType, outputShape); - const idxCopy: string[] = []; - const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys()); - const initProd = 'var prod = 1.0;'; - const initSum = 'var sum = 0.0;'; - const updateSum = 'sum += prod;'; - const reduceOpsSetIndices: string[] = []; - const reduceOpsLoopHeaders: string[] = []; - const reduceOpsLoopFooters: string[] = []; - const reduceOpCompute: string[] = []; - const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === rhsSymbols.length; - einsumEquation.symbolToInfo.forEach((info, symbol) => { - if (rhsSymbols.includes(symbol)) { - const outputIndex = rhsSymbols.indexOf(symbol); - einsumEquation.lhs.forEach((term, i) => { - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); +const appendMax = (name: string): string => name + '_max'; + +const createEinsumProgramInfo = + (enableInputShapesUniforms: readonly boolean[], inputShapes: Array, dataType: number, + einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => { + const shapeOrRanks = inputShapes.map((dims, index) => 
enableInputShapesUniforms[index] ? dims.length : dims); + const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank)); + const outputSize = ShapeUtil.size(outputShape); + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); + const uniformsSymbols = + [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const idxCopy: string[] = []; + const initProd = 'var prod = 1.0;'; + const initSum = 'var sum = 0.0;'; + const updateSum = 'sum += prod;'; + const reduceOpsSetIndices: string[] = []; + const reduceOpsLoopHeaders: string[] = []; + const reduceOpsLoopFooters: string[] = []; + const reduceOpCompute: string[] = []; + const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size; + einsumEquation.symbolToInfo.forEach((info, symbol) => { + if (einsumEquation.rhs.symbolToIndices.has(symbol)) { + const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0]; + if (outputIndex !== undefined) { + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + idxCopy.push(`${ + inputVars[i].indicesSet( + `input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); + }); + } + }); + } + } else { + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + 
reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); + }); + reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); + } + }); + reduceOpsLoopHeaders.push( + `for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`); + reduceOpsLoopFooters.push('}'); } - indices.forEach((index) => { - idxCopy.push(`${ - inputVars[i].indicesSet(`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); - }); - } - }); - } else { - einsumEquation.lhs.forEach((term, i) => { - const info = einsumEquation.symbolToInfo.get(symbol); - if (info === undefined) { - throw new Error('Invalid symbol error'); - } - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); + }); + const reduceOps = isReduceOpsWithoutLoop ? + [ + ...idxCopy, + `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` + ] : + [ + ...idxCopy, + initSum, + ...reduceOpsLoopHeaders, + ...reduceOpsSetIndices, + initProd, + ...reduceOpCompute, + updateSum, + ...reduceOpsLoopFooters, + ]; + return ` + ${ + shaderHelper + .registerUniforms(uniformsSymbols.map((symbol) => ({name: `${appendMax(symbol)}`, type: 'u32'}))) + .registerUniform('outputSize', 'u32') + .declareVariables(...inputVars, output)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + var outputIndices = ${output.offsetToIndices('global_idx')}; + ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} + ${reduceOps.join('\n')}; + ${output.setByOffset('global_idx', 'sum')}; + }`; + }; + return { + name: 'Einsum', + shaderCache: { + hint: einsumEquation.equation, + inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 
'rank' : 'dims') + }, + getRunData: () => { + // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The + // filter is added to make sure that dimValue is never 0. + const programUniformsInit: ProgramUniform[] = + uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) + .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); + programUniformsInit.push({type: 'uint32', data: outputSize}); + const programUniforms: ProgramUniform[] = + inputShapes.filter((_, index) => enableInputShapesUniforms[index]) + .map((dims, _) => [...createTensorShapeVariables(dims)]) + .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); } - indices.forEach((index) => { - reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); + return ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }); - reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); - } - }); - reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${ - einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`); - reduceOpsLoopFooters.push('}'); - } - }); - const reduceOps = isReduceOpsWithoutLoop ? 
- [ - ...idxCopy, - `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` - ] : - [ - ...idxCopy, - initSum, - ...reduceOpsLoopHeaders, - ...reduceOpsSetIndices, - initProd, - ...reduceOpCompute, - updateSum, - ...reduceOpsLoopFooters, - ]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(...inputVars, output)} - - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - var outputIndices = ${output.offsetToIndices('global_idx')}; - ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} - ${reduceOps.join('\n')}; - ${output.setByOffset('global_idx', 'sum')}; - }`; - return { - name: 'Einsum', - shaderCache: {hint: einsumEquation.equation}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} - }), - getShaderSource, - }; -}; + }, + getShaderSource, + }; + }; export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); - context.compute(createEinsumProgramInfo(context.inputs, einsumEquation)); + const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length)); + const outputShape = einsumEquation.outputDims; + const inputShapes = context.inputs.map((input, _) => input.dims); + context.compute(createEinsumProgramInfo( + enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); }; export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { diff --git a/js/web/test/data/ops/einsum.jsonc b/js/web/test/data/ops/einsum.jsonc index baf30cf982148..45bba6a121bd1 100644 --- a/js/web/test/data/ops/einsum.jsonc +++ b/js/web/test/data/ops/einsum.jsonc @@ -171,7 +171,7 @@ ], "cases": [ { - "name": 
"Diagonal elementwise multiplication", + "name": "Diagonal elements dot product", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], @@ -210,7 +210,7 @@ ], "cases": [ { - "name": "Dotproduct", + "name": "diagonal elements multiplication", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], @@ -233,6 +233,240 @@ } ] }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ij,ij -> ij", + "type": "string" + } + ], + "cases": [ + { + "name": "Elementwise multiplication", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1, 0, 0, 0, 1, 0, 0, 0, 1], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 0, 0, 5, 0, 0, 0, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,i", + "type": "string" + } + ], + "cases": [ + { + "name": "Dot product/scalar product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [6], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,j->ij", + "type": "string" + } + ], + "cases": [ + { + "name": "outer product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 2, 4, 6, 3, 6, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": 
"equation", + "data": "ij,ij -> ij", + "type": "string" + } + ], + "cases": [ + { + "name": "Elementwise multiplication", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1, 0, 0, 0, 1, 0, 0, 0, 1], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 0, 0, 5, 0, 0, 0, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,i", + "type": "string" + } + ], + "cases": [ + { + "name": "Dot product/scalar product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [6], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,j->ij", + "type": "string" + } + ], + "cases": [ + { + "name": "outer product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 2, 4, 6, 3, 6, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, { "name": "einsum", "operator": "Einsum", @@ -249,7 +483,7 @@ ], "cases": [ { - "name": "Multiply", + "name": "Multiply (2,3) X (3,4) -> (2,4)", "inputs": [ { "data": [1, 2, 3, 4, 5, 6], @@ -269,6 +503,28 @@ "type": "float32" } ] + }, + { + "name": "Multiply (2,6) X (6,4) -> (2,4)", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + "dims": [2, 6], + "type": "float32" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [6, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [220, 235, 250, 
265, 580, 631, 682, 733], + "dims": [2, 4], + "type": "float32" + } + ] } ] }, @@ -631,5 +887,73 @@ ] } ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ijk->ikj", + "type": "string" + } + ], + "cases": [ + { + "name": "Transpose with 3 dims", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [1, 2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 4, 2, 5, 3, 6], + "dims": [1, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "...ij->...ji", + "type": "string" + } + ], + "cases": [ + { + "name": "Transpose with ellipsis with input/output dims > 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [1, 1, 1, 2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 4, 2, 5, 3, 6], + "dims": [1, 1, 1, 3, 2], + "type": "float32" + } + ] + } + ] } ]