Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Add Einsum for CPU, WASM and WebGL #226

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ _This file is automatically generated from the def files via [this script](/tool
| [Div](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Div) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) |
| [Dropout](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Dropout) | [7-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-7), [10-11](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-10), [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-12) | | [7-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-7), [10-11](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-10), [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-12) |
| [DynamicQuantizeLinear](https://github.com/onnx/onnx/blob/master/docs/Operators.md#DynamicQuantizeLinear) | | | |
| [Einsum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Einsum) | | | |
| [Einsum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Einsum) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) |
| [Elu](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Elu) | [6+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Elu-6) | | [6+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Elu-6) |
| [Equal](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Equal) | | | [7-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Equal-7), [11+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Equal-11) |
| [Erf](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Erf) | | | |
Expand Down
2 changes: 2 additions & 0 deletions lib/backends/cpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {CpuCast} from './ops/cast';
import {CpuConcat} from './ops/concat';
import {CpuConv} from './ops/conv';
import {CpuDropout} from './ops/dropout';
import {CpuEinsum} from './ops/einsum';
import {CpuExpand} from './ops/expand';
import {CpuFlatten} from './ops/flatten';
import {CpuGather} from './ops/gather';
Expand Down Expand Up @@ -112,4 +113,5 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Upsample', '', '7-8', () => new CpuUpsample()],
['Upsample', '', '9', () => new CpuUpsampleV9()],
['Xor', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 ^ e2))],
['Einsum', '', '12+', () => new CpuEinsum()],
];
50 changes: 50 additions & 0 deletions lib/backends/cpu/ops/einsum.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Einsum} from '../../../ops/einsum';
import {Tensor} from '../../../tensor';
import {CpuInferenceHandler} from '../inference-handler';

import {ShapeUtil} from './../../../util';

export class CpuEinsum extends Einsum {
  /**
   * Executes the einsum contraction on the CPU backend.
   * Delegates index preparation to the shared `prepareRun` and the actual
   * accumulation to the standalone `einsum` helper below.
   */
  run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
    const prepared = this.prepareRun(inputs);
    const output =
        einsum(prepared.outputShape, inputs, prepared.sizes, prepared.outputIndices, prepared.inputIndices);
    return [output];
  }
}

/**
 * Computes a generalized tensor contraction (einsum) on CPU.
 *
 * Iterates over every combination of the full index space described by
 * `sizes`, multiplies the corresponding element from each input tensor, and
 * accumulates the product into the output position selected by
 * `outputIndices`.
 *
 * @param outputShape shape of the result tensor
 * @param inputs the input tensors participating in the contraction
 * @param sizes dimension size for every distinct subscript label
 * @param outputIndices for each output axis, its position in `sizes`/`index`
 * @param inputIndices for each input tensor, the positions in `sizes`/`index`
 *                     of each of its axes
 * @returns the contracted tensor
 */
export function einsum(
    outputShape: number[], inputs: Tensor[], sizes: number[], outputIndices: number[],
    inputIndices: number[][]): Tensor {
  const result = new Tensor(outputShape, inputs[0].type);
  const totalSize = ShapeUtil.size(sizes);
  // Multi-dimensional counter over the full (output + summed) index space.
  const index = new Array<number>(sizes.length).fill(0);

  // NOTE(review): the original used a `while` with a manual counter and an
  // inner `for (let i ...)` that shadowed the outer counter `i`; renamed to
  // avoid the shadowing hazard.
  for (let step = 0; step < totalSize; step++) {
    // Project the full index onto the output tensor's axes.
    const outputIx = outputIndices.map(axis => index[axis]);

    // Product of the matching element from every input tensor.
    let value = 1;
    for (let tensorIdx = 0; tensorIdx < inputIndices.length; tensorIdx++) {
      const inputIx = inputIndices[tensorIdx].map(axis => index[axis]);
      value *= inputs[tensorIdx].get(inputIx) as number;
    }

    // Accumulate: summed axes map to the same output position repeatedly.
    result.set(outputIx, (result.get(outputIx) as number) + value);

    ShapeUtil.incrementIndex(index, sizes);
  }

  return result;
}
2 changes: 2 additions & 0 deletions lib/backends/wasm/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {WasmBatchNormalization} from './ops/batch-normalization';
import {WasmBinaryOp} from './ops/binary-op';
import {WasmClip} from './ops/clip';
import {WasmConv} from './ops/conv';
import {WasmEinsum} from './ops/einsum';
import {WasmGemm} from './ops/gemm';
import {WasmInstanceNormalization} from './ops/instance-normalization';
import {WasmMatMul} from './ops/matmul';
Expand Down Expand Up @@ -36,4 +37,5 @@ export const WASM_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Sub', '', '7+', () => new WasmBinaryOp(['float32'], 'Sub')],
['Sum', '', '6+', () => new WasmSum()], // TODO: support multidirectional broadcast for Sum-8
['Xor', '', '7+', () => new WasmBinaryOp(['bool'], 'Xor')],
['Einsum', '', '12+', () => new WasmEinsum()],
];
63 changes: 63 additions & 0 deletions lib/backends/wasm/ops/einsum.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Attribute} from '../../../attribute';
import {Einsum} from '../../../ops/einsum';
import {Tensor} from '../../../tensor';
import {WasmBinding} from '../../../wasm-binding';
import {WasmInferenceHandler} from '../inference-handler';

export class WasmEinsum extends Einsum {
  /**
   * Validates attributes. The wasm kernels only handle one or two inputs, so
   * reject anything larger at initialization time.
   */
  initialize(attributes: Attribute): void {
    super.initialize(attributes);
    if (this.inputs.length > 2) {
      throw new Error('Wasm implementation of Einsum currently supports at most 2 inputs');
    }
  }

  /**
   * Runs the contraction by calling the appropriate wasm kernel
   * (`_einsum_f32` for two inputs, `_einsum_single_f32` for one).
   */
  run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
    const {outputShape, sizes, outputIndices, inputIndices} = this.prepareRun(inputs);

    const y = new Tensor(outputShape, inputs[0].type);

    if (inputs.length === 2) {
      WasmBinding.getInstance().ccall(
          '_einsum_f32',
          [inputs[0].floatData, 'float32ptr'],
          [inputs[1].floatData, 'float32ptr'],
          [y.floatData, 'float32ptr', 'inout'],
          [sizes, 'int32ptr'],
          [sizes.length, 'int32'],
          [outputIndices, 'int32ptr'],
          [outputIndices.length, 'int32'],
          [inputIndices[0], 'int32ptr'],
          [inputIndices[0].length, 'int32'],
          [inputIndices[1], 'int32ptr'],
          // BUG FIX: was `inputIndices[2].length` — index 2 never exists when
          // there are exactly two inputs; the second input's index count is
          // `inputIndices[1].length`.
          [inputIndices[1].length, 'int32'],
      );
    } else {
      WasmBinding.getInstance().ccall(
          '_einsum_single_f32',
          [inputs[0].floatData, 'float32ptr'],
          [y.floatData, 'float32ptr', 'inout'],
          [sizes, 'int32ptr'],
          [sizes.length, 'int32'],
          [outputIndices, 'int32ptr'],
          [outputIndices.length, 'int32'],
          [inputIndices[0], 'int32ptr'],
          // BUG FIX: was `inputIndices[1].length` — with a single input only
          // index 0 exists; the length must describe `inputIndices[0]`.
          [inputIndices[0].length, 'int32'],
      );
    }

    return [y];
  }

  checkInputTypes(inputs: Tensor[]): boolean {
    // currently Wasm backend only supports 'float32' input type
    if (inputs[0].type !== 'float32' || (inputs.length > 1 && inputs[1].type !== 'float32')) {
      return false;
    }

    return super.checkInputTypes(inputs);
  }
}
2 changes: 2 additions & 0 deletions lib/backends/webgl/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {WebGLClip} from './ops/clip';
import {WebGLConcat} from './ops/concat';
import {WebGLConv} from './ops/conv';
import {WebGLDropout} from './ops/dropout';
import {WebGLEinsum} from './ops/einsum';
import {WebGLElu} from './ops/elu';
import {WebGLFlatten} from './ops/flatten';
import {WebGLGather} from './ops/gather';
Expand Down Expand Up @@ -105,4 +106,5 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Upsample', '', '7-8', () => new WebGLUpsample()],
['Unsqueeze', '', '1+', () => new WebGLUnsqueeze()],
['Xor', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslXor())],
['Einsum', '', '1+', () => new WebGLEinsum()],
];
110 changes: 110 additions & 0 deletions lib/backends/webgl/ops/einsum.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Einsum} from '../../../ops/einsum';
import {Tensor} from '../../../tensor';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData, WebGLOperator} from '../types';

import {ShapeUtil} from './../../../util';

// One single-letter GLSL sampler name per potential input tensor (A..Z).
const samplerNames = [...'ABCDEFGHIJKLMNOPQRSTUVWXYZ'];

/**
 * WebGL implementation of Einsum. Generates a GLSL `process` function that,
 * for every output element, loops over the summed-out index space,
 * multiplying one element from each input per iteration and accumulating the
 * products.
 */
export class WebGLEinsum extends Einsum implements WebGLOperator {
/** Dispatches this operator through the WebGL inference handler. */
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
return inferenceHandler.run(this, inputs);
}

/**
 * Builds the shader program: computes which dimensions are summed out
 * (those not present in the output), then emits GLSL that iterates the
 * summed index space and accumulates products of the sampled inputs.
 */
createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const {outputShape, sizes, outputIndices, inputIndices} = this.prepareRun(inputs);

// Dimensions absent from the output are contracted (summed) over.
const sumDims = [];
const sumDimSizes = [];
for (let i = 0; i < sizes.length; i++) {
if (outputIndices.indexOf(i) === -1) {
sumDims.push(i);
sumDimSizes.push(sizes[i]);
}
}
// Total number of summation iterations per output element.
const sumSize = ShapeUtil.size(sumDimSizes);

let rank = outputShape.length;
// GLSL does not allow zero-length arrays, so a scalar output is treated as rank 1.
if (rank === 0) {
rank = 1;
}

// GLSL snippets: copy the output indices into the full index array, and
// zero-initialize the summed dimensions.
const initIndex1 = outputIndices.map((x, i) => `index[${x}] = indices[${i}];`).join('\n');
const initIndex2 = sumDims.map(x => `index[${x}] = 0;`).join('\n');

// One snippet per input: gather its element at the current full index.
const findInputValues = inputs.map((_, i) => this.buildFindInputValueScript(i, inputIndices[i])).join('\n');

const incrementIndex = this.buildIncrementIndexScript(sumDims, sumDimSizes);

const shaderSource = `
float process(int indices[${rank}]) {
float value = 0.0;

int index[${sizes.length}];
${initIndex1}
${initIndex2}

int i = 0;
while(i < ${sumSize}) {
float add = 1.0;

${findInputValues}

value += add;

${incrementIndex}
i++;
}

return value;
}`;
const inputLayouts = inputs.map(t => inferenceHandler.getOrCreateTextureLayout(t));
return {
inputLayouts,
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape),
samplers: samplerNames.slice(0, inputs.length),
shaderSource,
};
}

/**
 * Emits GLSL that builds input `i`'s index array from the full index and
 * multiplies its sampled value into the running product `add`.
 */
buildFindInputValueScript(i: number, indices: number[]): string {
const initInputIndex = indices.map((ix, indiceNum) => `input${i}Index[${indiceNum}] = index[${ix}];`).join('\n');

const script = `int input${i}Index[${indices.length}];
${initInputIndex}
add *= _${samplerNames[i]}(input${i}Index);`;

return script;
}

/**
 * Emits GLSL odometer-style carry logic over the summed dimensions:
 * increment the first summed dimension; on overflow reset it and fall
 * through (via nested ifs) to increment the next one.
 */
buildIncrementIndexScript(sumDims: number[], sumDimSizes: number[]): string {
let script = '';
for (let i = 0; i < sumDims.length; i++) {
script += `
index[${sumDims[i]}] += 1;
if (index[${sumDims[i]}] >= ${sumDimSizes[i]}) {
index[${sumDims[i]}] = 0;
`;
}
// Close every `if` opened above (one brace per summed dimension).
for (let i = 0; i < sumDims.length; i++) {
script += '}\n';
}

return script;
}

/** Binds input textures and allocates the output texture for this run. */
createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
const inputTDs = inputs.map((v, i) => inferenceHandler.getOrCreateTextureData(v, programInfo.inputLayouts[i]));
return {
inputTextureDatas: inputTDs,
outputTextureData:
inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
uniformData: {}
};
}
}
Loading