From d17bd4db64c4f52a910e6f7d0aee0641b68af909 Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Mon, 7 Sep 2020 12:27:51 +0200 Subject: [PATCH 01/10] Add first version of einsum cpu implementation --- lib/backends/cpu/op-resolve-rules.ts | 2 + lib/backends/cpu/ops/einsum.ts | 65 ++++++++++++++ lib/ops/einsum.ts | 114 +++++++++++++++++++++++++ test/data/ops/einsum.jsonc | 121 +++++++++++++++++++++++++++ 4 files changed, 302 insertions(+) create mode 100644 lib/backends/cpu/ops/einsum.ts create mode 100644 lib/ops/einsum.ts create mode 100644 test/data/ops/einsum.jsonc diff --git a/lib/backends/cpu/op-resolve-rules.ts b/lib/backends/cpu/op-resolve-rules.ts index c77a8b7e..2a978b01 100644 --- a/lib/backends/cpu/op-resolve-rules.ts +++ b/lib/backends/cpu/op-resolve-rules.ts @@ -11,6 +11,7 @@ import {CpuCast} from './ops/cast'; import {CpuConcat} from './ops/concat'; import {CpuConv} from './ops/conv'; import {CpuDropout} from './ops/dropout'; +import {CpuEinsum} from './ops/einsum'; import {CpuExpand} from './ops/expand'; import {CpuFlatten} from './ops/flatten'; import {CpuGather} from './ops/gather'; @@ -112,4 +113,5 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray = [ ['Upsample', '', '7-8', () => new CpuUpsample()], ['Upsample', '', '9', () => new CpuUpsampleV9()], ['Xor', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 ^ e2))], + ['Einsum', '', '12+', () => new CpuEinsum()], ]; diff --git a/lib/backends/cpu/ops/einsum.ts b/lib/backends/cpu/ops/einsum.ts new file mode 100644 index 00000000..49cc8c83 --- /dev/null +++ b/lib/backends/cpu/ops/einsum.ts @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Einsum} from '../../../ops/einsum'; +import {Tensor} from '../../../tensor'; +import {CpuInferenceHandler} from '../inference-handler'; + +import {ShapeUtil} from './../../../util'; + +export class CpuEinsum extends Einsum { + run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] { + const dimensionSizeMap: {[id: string]: number} = {}; + this.matchInputs(inputs, dimensionSizeMap); + const outputShape = this.calculateOutputSize(dimensionSizeMap); + + const result = new Tensor(outputShape, inputs[0].type); + + let i = 0; + const index = []; + const sizes = []; + const nameToId: {[name: string]: number} = {}; + const idToName: {[id: number]: string} = {}; + for (const name in dimensionSizeMap) { + index.push(0); + sizes.push(dimensionSizeMap[name]); + nameToId[name] = i; + idToName[i] = name; + i++; + } + + const totalSize = ShapeUtil.size(sizes); + i = 0; + + while (i < totalSize) { + const outputIx: number[] = []; + for (const outputName of this.outputIndices) { + outputIx.push(index[nameToId[outputName]]); + } + + const input1Ix: number[] = []; + for (const input1Name of this.input1Indices) { + input1Ix.push(index[nameToId[input1Name]]); + } + let value = inputs[0].get(input1Ix) as number; + if (this.input2) { + const input2Ix: number[] = []; + for (const input2Name of this.input2Indices) { + input2Ix.push(index[nameToId[input2Name]]); + } + value *= inputs[1].get(input2Ix) as number; + } + + result.set(outputIx, result.get(outputIx) as number + value); + + i++; + ShapeUtil.incrementIndex(index, sizes); + } + + return [result]; + } +} + +export function einsum(a: Tensor, b: Tensor) { + return undefined; +} diff --git a/lib/ops/einsum.ts b/lib/ops/einsum.ts new file mode 100644 index 00000000..300ff1e3 --- /dev/null +++ b/lib/ops/einsum.ts @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Attribute} from '../attribute'; +import {InferenceHandler} from '../backend'; +import {Operator} from '../operators'; +import {Tensor} from '../tensor'; + +export abstract class Einsum implements Operator { + abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise; + + initialize(attributes: Attribute): void { + this.equation = attributes.getString('equation'); + const split = this.equation.split('->'); + this.lhs = split[0].trim(); + if (split.length === 2) { + this.rhs = split[1].trim(); + this.implicit = false; + } else { + this.implicit = true; + } + + const lhsSplit = this.lhs.split(','); + this.input1 = lhsSplit[0].trim(); + if (lhsSplit.length === 2) { + this.input2 = lhsSplit[1].trim(); + } + + this.parseEquationPart(this.input1, this.input1Indices); + if (this.input2) { + this.parseEquationPart(this.input2, this.input2Indices); + } + if (this.rhs) { + this.parseEquationPart(this.rhs, this.outputIndices); + } + } + + private parseEquationPart(part: string, indices: string[]) { + for (let i = 0; i < part.length; i++) { + const char = part.charAt(i); + + if (char === '.') { + throw new Error('Use of ellipsis (...) in einsum not yet supported'); + } + + indices.push(char); + } + } + + protected matchInputs(inputs: Tensor[], dimensionMap: {[id: string]: number}) { + this.matchDimensions(this.input1Indices, inputs[0].dims, dimensionMap); + if (this.input2) { + this.matchDimensions(this.input2Indices, inputs[1].dims, dimensionMap); + } + } + + protected calculateOutputSize(dimensionMap: {[id: string]: number}): number[] { + if (this.outputIndices.length === 0) { + return []; + } + + const result: number[] = []; + for (let i = 0; i < this.outputIndices.length; i++) { + result.push(dimensionMap[this.outputIndices[i]]); + } + return result; + } + + checkInputs(inputs: Tensor[]): boolean { + const dimensionMap: {[id: string]: number} = {}; + + this.matchDimensions(this.input1Indices, inputs[0].dims, dimensionMap); + if (this.input2) { + this.matchDimensions(this.input2Indices, inputs[1].dims, dimensionMap); + } + + return this.checkInputTypes(inputs); + } + + protected matchDimensions(indices: string[], inputDims: readonly number[], dimensionMap: {[id: string]: number}): + boolean { + for (let j = 0; j < indices.length; j++) { + const ix = indices[j]; + if (dimensionMap[ix] && dimensionMap[ix] !== inputDims[j]) { + return false; + } else if (!dimensionMap[ix]) { + dimensionMap[ix] = inputDims[j]; + } + } + + return true; + } + + protected checkInputTypes(inputs: Tensor[]): boolean { + return true; + } + + protected equation: string; + protected lhs: string; + protected rhs?: string; + + protected input1: string; + protected input2?: string; + + // Maps from input 1 axis to general axis id + protected input1Indices: string[] = []; + // Maps from input 2 axis to general axis id + protected input2Indices: string[] = []; + + // Maps from output axis to general axis id + protected outputIndices: string[] = []; + + protected implicit: boolean; +} diff --git a/test/data/ops/einsum.jsonc b/test/data/ops/einsum.jsonc new file mode 100644 index 00000000..c8a345bc --- /dev/null +++ b/test/data/ops/einsum.jsonc @@ -0,0 +1,121 @@ +[ + { + "name": "Einsum batch matmul", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [ + { "name": "equation", "data": "bij, bjk -> bik", "type": "string" } + ], + "cases": [ + { + "name": "BMM", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 2, 2], + "type": "float32" + }, + { + "data": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24 + ], + "dims": [3, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 11, + 14, + 17, + 20, + 23, + 30, + 37, + 44, + 123, + 134, + 145, + 156, + 167, + 182, + 197, + 212, + 363, + 382, + 401, + 420, + 439, + 462, + 485, + 508 + ], + "dims": [3, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Einsum transpose", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [ + { "name": "equation", "data": "ij -> ji", "type": "string" } + ], + "cases": [ + { + "name": "BMM", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12], + "dims": [4, 3], + "type": "float32" + } + ] + } + ] + } +] From da89002e79bb54e883a649d3b6d218726a74b958 Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Mon, 7 Sep 2020 12:33:41 +0200 Subject: [PATCH 02/10] Add more einsum operator tests --- test/data/ops/einsum.jsonc | 114 +++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/test/data/ops/einsum.jsonc b/test/data/ops/einsum.jsonc index c8a345bc..c4fb8ebb 100644 --- a/test/data/ops/einsum.jsonc +++ b/test/data/ops/einsum.jsonc @@ -117,5 +117,119 @@ ] } ] + }, + { + "name": "Einsum inner product", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [{ "name": "equation", "data": "i,i", "type": "string" }], + "cases": [ + { + "name": "BMM", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [55], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Einsum batch diagonal", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [{ "name": "equation", "data": "bii->bi", "type": "string" }], + "cases": [ + { + "name": "BMM", + "inputs": [ + { + "data": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18 + ], + "dims": [2, 3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 10, 14, 18], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Einsum sum", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [{ "name": "equation", "data": "ij->i", "type": "string" }], + "cases": [ + { + "name": "BMM", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 26, 42], + "dims": [3], + "type": "float32" + } + ] + } + ] } ] From f538c7817ba5d18362afb000a5b033da45bee196 Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Mon, 7 Sep 2020 12:52:00 +0200 Subject: [PATCH 03/10] Refactor einsum implementation --- lib/backends/cpu/ops/einsum.ts | 81 +++++++++++++++++++++------------- lib/ops/einsum.ts | 52 ++++++++++++++-------- 2 files changed, 83 insertions(+), 50 deletions(-) diff --git a/lib/backends/cpu/ops/einsum.ts b/lib/backends/cpu/ops/einsum.ts index 49cc8c83..cb0ece54 100644 --- a/lib/backends/cpu/ops/einsum.ts +++ b/lib/backends/cpu/ops/einsum.ts @@ -9,57 +9,76 @@ import {ShapeUtil} from './../../../util'; export class CpuEinsum extends Einsum { run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] { - const dimensionSizeMap: {[id: string]: number} = {}; + const dimensionSizeMap: {[name: string]: number} = {}; this.matchInputs(inputs, dimensionSizeMap); const outputShape = this.calculateOutputSize(dimensionSizeMap); - const result = new Tensor(outputShape, inputs[0].type); - let i = 0; - const index = []; const sizes = []; const nameToId: {[name: string]: number} = {}; const idToName: {[id: number]: string} = {}; + for (const name in dimensionSizeMap) { - index.push(0); sizes.push(dimensionSizeMap[name]); nameToId[name] = i; idToName[i] = name; i++; } - const totalSize = ShapeUtil.size(sizes); - i = 0; - - while (i < totalSize) { - const outputIx: number[] = []; - for (const outputName of this.outputIndices) { - outputIx.push(index[nameToId[outputName]]); + const outputIndices: number[] = []; + const input1Indices: number[] = []; + const input2Indices: number[] = []; + for (const outputName of this.outputNames) { + outputIndices.push(nameToId[outputName]); + } + for (const inputName of this.input1Names) { + input1Indices.push(nameToId[inputName]); + } + if (this.input2) { + for (const inputName of this.input2Names) { + input2Indices.push(nameToId[inputName]); } + } - const input1Ix: number[] = []; - for (const input1Name of this.input1Indices) { - input1Ix.push(index[nameToId[input1Name]]); - } - let value = inputs[0].get(input1Ix) as number; - if (this.input2) { - const input2Ix: number[] = []; - for (const input2Name of this.input2Indices) { - input2Ix.push(index[nameToId[input2Name]]); - } - value *= inputs[1].get(input2Ix) as number; - } + const result = + einsum(outputShape, inputs, sizes, outputIndices, input1Indices, this.input2 ? input2Indices : undefined); + + return [result]; + } +} - result.set(outputIx, result.get(outputIx) as number + value); +export function einsum( + outputShape: number[], inputs: Tensor[], sizes: number[], outputIndices: number[], input1Indices: number[], + input2Indices?: number[]): Tensor { + const result = new Tensor(outputShape, inputs[0].type); + const totalSize = ShapeUtil.size(sizes); + let i = 0; + const index = new Array(sizes.length).fill(0); - i++; - ShapeUtil.incrementIndex(index, sizes); + while (i < totalSize) { + const outputIx: number[] = []; + for (const outputIndex of outputIndices) { + outputIx.push(index[outputIndex]); } - return [result]; + const input1Ix: number[] = []; + for (const input1Index of input1Indices) { + input1Ix.push(index[input1Index]); + } + let value = inputs[0].get(input1Ix) as number; + if (input2Indices) { + const input2Ix: number[] = []; + for (const input2Index of input2Indices) { + input2Ix.push(index[input2Index]); + } + value *= inputs[1].get(input2Ix) as number; + } + + result.set(outputIx, result.get(outputIx) as number + value); + + i++; + ShapeUtil.incrementIndex(index, sizes); } -} -export function einsum(a: Tensor, b: Tensor) { - return undefined; + return result; } diff --git a/lib/ops/einsum.ts b/lib/ops/einsum.ts index 300ff1e3..4d93dc4b 100644 --- a/lib/ops/einsum.ts +++ b/lib/ops/einsum.ts @@ -26,12 +26,12 @@ export abstract class Einsum implements Operator { this.input2 = lhsSplit[1].trim(); } - this.parseEquationPart(this.input1, this.input1Indices); + this.parseEquationPart(this.input1, this.input1Names); if (this.input2) { - this.parseEquationPart(this.input2, this.input2Indices); + this.parseEquationPart(this.input2, this.input2Names); } if (this.rhs) { - this.parseEquationPart(this.rhs, this.outputIndices); + this.parseEquationPart(this.rhs, this.outputNames); } } @@ -47,21 +47,17 @@ export abstract class Einsum implements Operator { } } - protected matchInputs(inputs: Tensor[], dimensionMap: {[id: string]: number}) { - this.matchDimensions(this.input1Indices, inputs[0].dims, dimensionMap); + protected matchInputs(inputs: Tensor[], dimensionSizeMap: {[name: string]: number}) { + this.matchDimensions(this.input1Names, inputs[0].dims, dimensionSizeMap); if (this.input2) { - this.matchDimensions(this.input2Indices, inputs[1].dims, dimensionMap); + this.matchDimensions(this.input2Names, inputs[1].dims, dimensionSizeMap); } } - protected calculateOutputSize(dimensionMap: {[id: string]: number}): number[] { - if (this.outputIndices.length === 0) { - return []; - } - + protected calculateOutputSize(dimensionSizeMap: {[name: string]: number}): number[] { const result: number[] = []; - for (let i = 0; i < this.outputIndices.length; i++) { - result.push(dimensionMap[this.outputIndices[i]]); + for (let i = 0; i < this.outputNames.length; i++) { + result.push(dimensionSizeMap[this.outputNames[i]]); } return result; } @@ -69,9 +65,18 @@ export abstract class Einsum implements Operator { checkInputs(inputs: Tensor[]): boolean { const dimensionMap: {[id: string]: number} = {}; - this.matchDimensions(this.input1Indices, inputs[0].dims, dimensionMap); - if (this.input2) { - this.matchDimensions(this.input2Indices, inputs[1].dims, dimensionMap); + if (inputs.length < 1 || inputs.length > 2) { + return false; + } + + if (!this.matchDimensions(this.input1Names, inputs[0].dims, dimensionMap)) { + return false; + } + + if (this.input2 && inputs.length < 2) { + return false; + } else if (this.input2 && !this.matchDimensions(this.input2Names, inputs[1].dims, dimensionMap)) { + return false; } return this.checkInputTypes(inputs); @@ -92,6 +97,15 @@ export abstract class Einsum implements Operator { } protected checkInputTypes(inputs: Tensor[]): boolean { + const allowedTypes = ['float32', 'float64', 'int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32']; + if (allowedTypes.indexOf(inputs[0].type) === -1 || + (inputs.length > 1 && allowedTypes.indexOf(inputs[1].type) === -1)) { + return false; + } + + if (inputs.length > 1 && inputs[0].type !== inputs[1].type) { + return false; + } return true; } @@ -103,12 +117,12 @@ export abstract class Einsum implements Operator { protected input2?: string; // Maps from input 1 axis to general axis id - protected input1Indices: string[] = []; + protected input1Names: string[] = []; // Maps from input 2 axis to general axis id - protected input2Indices: string[] = []; + protected input2Names: string[] = []; // Maps from output axis to general axis id - protected outputIndices: string[] = []; + protected outputNames: string[] = []; protected implicit: boolean; } From b51e6ba7d21c0ec6635d35ae38c72677825e1cc5 Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Mon, 7 Sep 2020 13:44:18 +0200 Subject: [PATCH 04/10] Add wasm implementation of einsum --- lib/backends/wasm/op-resolve-rules.ts | 2 + lib/backends/wasm/ops/einsum.ts | 84 ++++++++++++++++ src/wasm-build-config.json | 4 +- src/wasm-ops/einsum.cpp | 132 ++++++++++++++++++++++++++ src/wasm-ops/einsum.h | 20 ++++ src/wasm-ops/utils/shape_utils.cpp | 12 +++ src/wasm-ops/utils/shape_utils.h | 3 + 7 files changed, 256 insertions(+), 1 deletion(-) create mode 100644 lib/backends/wasm/ops/einsum.ts create mode 100644 src/wasm-ops/einsum.cpp create mode 100644 src/wasm-ops/einsum.h diff --git a/lib/backends/wasm/op-resolve-rules.ts b/lib/backends/wasm/op-resolve-rules.ts index 92fb6141..cee2e91d 100644 --- a/lib/backends/wasm/op-resolve-rules.ts +++ b/lib/backends/wasm/op-resolve-rules.ts @@ -1,3 +1,4 @@ +import { WasmEinsum } from './ops/einsum'; // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. @@ -36,4 +37,5 @@ export const WASM_OP_RESOLVE_RULES: ReadonlyArray = [ ['Sub', '', '7+', () => new WasmBinaryOp(['float32'], 'Sub')], ['Sum', '', '6+', () => new WasmSum()], // TODO: support multidirectional broadcast for Sum-8 ['Xor', '', '7+', () => new WasmBinaryOp(['bool'], 'Xor')], + ['Einsum', '', '12+', () => new WasmEinsum()], // TODO: support multidirectional broadcast for Sum-8 ]; diff --git a/lib/backends/wasm/ops/einsum.ts b/lib/backends/wasm/ops/einsum.ts new file mode 100644 index 00000000..16bf7e81 --- /dev/null +++ b/lib/backends/wasm/ops/einsum.ts @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Einsum} from '../../../ops/einsum'; +import {Tensor} from '../../../tensor'; +import {WasmBinding} from '../../../wasm-binding'; +import {WasmInferenceHandler} from '../inference-handler'; + +export class WasmEinsum extends Einsum { + run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] { + const dimensionSizeMap: {[name: string]: number} = {}; + this.matchInputs(inputs, dimensionSizeMap); + const outputShape = this.calculateOutputSize(dimensionSizeMap); + + let i = 0; + const sizes = []; + const nameToId: {[name: string]: number} = {}; + const idToName: {[id: number]: string} = {}; + + for (const name in dimensionSizeMap) { + sizes.push(dimensionSizeMap[name]); + nameToId[name] = i; + idToName[i] = name; + i++; + } + + const outputIndices: number[] = []; + const input1Indices: number[] = []; + const input2Indices: number[] = []; + for (const outputName of this.outputNames) { + outputIndices.push(nameToId[outputName]); + } + for (const inputName of this.input1Names) { + input1Indices.push(nameToId[inputName]); + } + if (this.input2) { + for (const inputName of this.input2Names) { + input2Indices.push(nameToId[inputName]); + } + } + + const y = new Tensor(outputShape, inputs[0].type); + + if (inputs.length === 2) { + WasmBinding.getInstance().ccall( + '_einsum_f32', + [inputs[0].floatData, 'float32ptr'], + [inputs[1].floatData, 'float32ptr'], + [y.floatData, 'float32ptr', 'inout'], + [sizes, 'int32ptr'], + [sizes.length, 'int32'], + [outputIndices, 'int32ptr'], + [outputIndices.length, 'int32'], + [input1Indices, 'int32ptr'], + [input1Indices.length, 'int32'], + [input1Indices, 'int32ptr'], + [input2Indices.length, 'int32'], + ); + } else { + WasmBinding.getInstance().ccall( + '_einsum_single_f32', + [inputs[0].floatData, 'float32ptr'], + [y.floatData, 'float32ptr', 'inout'], + [sizes, 'int32ptr'], + [sizes.length, 'int32'], + [outputIndices, 'int32ptr'], + [outputIndices.length, 'int32'], + [input1Indices, 'int32ptr'], + [input1Indices.length, 'int32'], + ); + } + + return [y]; + } + + checkInputTypes(inputs: Tensor[]): boolean { + // currently Wasm backend only supports 'float32' input type + if (inputs[0].type !== 'float32' || (inputs.length > 1 && inputs[1].type !== 'float32')) { + return false; + } + + return super.checkInputTypes(inputs); + } +} diff --git a/src/wasm-build-config.json b/src/wasm-build-config.json index 95a99ee8..5f6a5f92 100644 --- a/src/wasm-build-config.json +++ b/src/wasm-build-config.json @@ -22,6 +22,8 @@ "_clip_f32", "_instance_normalization_f32", "_sum_f32", - "_softmax_f32" + "_softmax_f32", + "_einsum_f32", + "_einsum_single_f32" ] } diff --git a/src/wasm-ops/einsum.cpp b/src/wasm-ops/einsum.cpp new file mode 100644 index 00000000..bdc0fb68 --- /dev/null +++ b/src/wasm-ops/einsum.cpp @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "einsum.h" +#include "common.h" +#include "utils/shape_utils.h" + +// Wasm interop method +void einsum_f32(void *data) { + uint32_t *dataIndex = static_cast(data); + uint32_t const argc = dataIndex[0]; + + const float *a = PARAM_FLOAT_PTR(data, dataIndex[1]); + const float *b = PARAM_FLOAT_PTR(data, dataIndex[2]); + float *y = PARAM_FLOAT_PTR(data, dataIndex[3]); + const int32_t *dims = PARAM_INT32_PTR(data, dataIndex[4]); + const int32_t rank = PARAM_INT32(data, dataIndex[5]); + const int32_t *outputIndices = PARAM_INT32_PTR(data, dataIndex[6]); + const int32_t outputRank = PARAM_INT32(data, dataIndex[7]); + const int32_t *input1Indices = PARAM_INT32_PTR(data, dataIndex[8]); + const int32_t input1Rank = PARAM_INT32(data, dataIndex[9]); + const int32_t *input2Indices = PARAM_INT32_PTR(data, dataIndex[10]); + const int32_t input2Rank = PARAM_INT32(data, dataIndex[11]); + + einsum_f32_imp(a, b, y, dims, rank, outputIndices, outputRank, input1Indices, + input1Rank, input2Indices, input2Rank); +} + +void einsum_f32_imp(const float *A, const float *B, float *Y, + const int32_t *dims, const int32_t rank, + const int32_t *outputIndices, int32_t outputRank, + const int32_t *input1Indices, int32_t input1Rank, + const int32_t *input2Indices, int32_t input2Rank) { + std::vector dimsVector(dims, dims + rank); + // std::vector strides = ShapeUtils::compute_strides(dimsVector); + size_t totalSize = ShapeUtils::size_from_dims(dimsVector); + size_t i = 0; + std::vector index(rank, 0); + + std::vector outputStrides(outputRank, 1); + for (size_t j = outputRank - 2; j >= 0; j--) { + outputStrides[j] = outputStrides[j + 1] * dimsVector[outputIndices[j]]; + } + + std::vector input1Strides(input1Rank, 1); + for (size_t j = input1Rank - 2; j >= 0; j--) { + input1Strides[j] = input1Strides[j + 1] * dimsVector[input1Indices[j]]; + } + + std::vector input2Strides(input2Rank, 1); + for (size_t j = input2Rank - 2; j >= 0; j--) { + input2Strides[j] = input2Strides[j + 1] * dimsVector[input2Indices[j]]; + } + + while (i < totalSize) { + size_t outputOffset = 0; + for (size_t j = 0; j < outputRank; j++) { + outputOffset += index[outputIndices[j]] * outputStrides[j]; + } + + size_t input1Offset = 0; + for (size_t j = 0; j < input1Rank; j++) { + input1Offset += index[input1Indices[j]] * input1Strides[j]; + } + + size_t input2Offset = 0; + for (size_t j = 0; j < input2Rank; j++) { + input2Offset += index[input2Indices[j]] * input2Strides[j]; + } + + Y[outputOffset] += A[input1Offset] * B[input2Offset]; + + i++; + ShapeUtils::increment_index(index, dimsVector, dimsVector.size()); + } +} + +void einsum_single_f32(void *data) { + uint32_t *dataIndex = static_cast(data); + uint32_t const argc = dataIndex[0]; + + const float *a = PARAM_FLOAT_PTR(data, dataIndex[1]); + float *y = PARAM_FLOAT_PTR(data, dataIndex[2]); + const int32_t *dims = PARAM_INT32_PTR(data, dataIndex[3]); + const int32_t rank = PARAM_INT32(data, dataIndex[4]); + const int32_t *outputIndices = PARAM_INT32_PTR(data, dataIndex[5]); + const int32_t outputRank = PARAM_INT32(data, dataIndex[6]); + const int32_t *inputIndices = PARAM_INT32_PTR(data, dataIndex[7]); + const int32_t inputRank = PARAM_INT32(data, dataIndex[8]); + + einsum_single_f32_imp(a, y, dims, rank, outputIndices, outputRank, + inputIndices, inputRank); +} + +// Core operator implementation +void einsum_single_f32_imp(const float *A, float *Y, const int32_t *dims, + const int32_t rank, const int32_t *outputIndices, + int32_t outputRank, const int32_t *inputIndices, + int32_t inputRank) { + std::vector dimsVector(dims, dims + rank); + // std::vector strides = ShapeUtils::compute_strides(dimsVector); + size_t totalSize = ShapeUtils::size_from_dims(dimsVector); + size_t i = 0; + std::vector index(rank, 0); + + std::vector outputStrides(outputRank, 1); + for (size_t j = outputRank - 2; j >= 0; j--) { + outputStrides[j] = outputStrides[j + 1] * dimsVector[outputIndices[j]]; + } + + std::vector inputStrides(inputRank, 1); + for (size_t j = inputRank - 2; j >= 0; j--) { + inputStrides[j] = inputStrides[j + 1] * dimsVector[inputIndices[j]]; + } + + while (i < totalSize) { + size_t outputOffset = 0; + for (size_t j = 0; j < outputRank; j++) { + outputOffset += index[outputIndices[j]] * outputStrides[j]; + } + + size_t input1Offset = 0; + for (size_t j = 0; j < inputRank; j++) { + input1Offset += index[inputIndices[j]] * inputStrides[j]; + } + + Y[outputOffset] += A[input1Offset]; + + i++; + ShapeUtils::increment_index(index, dimsVector, dimsVector.size()); + } +} diff --git a/src/wasm-ops/einsum.h b/src/wasm-ops/einsum.h new file mode 100644 index 00000000..017845cc --- /dev/null +++ b/src/wasm-ops/einsum.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include + +extern "C" { +void einsum_f32(void *); +void einsum_f32_imp(const float *A, const float *B, float *Y, + const int32_t *dims, const int32_t rank, + const int32_t *outputIndices, int32_t outputRank, + const int32_t *input1Indices, int32_t input1Rank, + const int32_t *input2Indices, int32_t input2Rank); +void einsum_single_f32(void *); +void einsum_single_f32_imp(const float *A, float *Y, const int32_t *dims, + const int32_t rank, const int32_t *outputIndices, + int32_t outputRank, const int32_t *inputIndices, + int32_t inputRank); +} diff --git a/src/wasm-ops/utils/shape_utils.cpp b/src/wasm-ops/utils/shape_utils.cpp index bd068b83..7d2e9e1d 100644 --- a/src/wasm-ops/utils/shape_utils.cpp +++ b/src/wasm-ops/utils/shape_utils.cpp @@ -86,3 +86,15 @@ void ShapeUtils::offset_to_indices(const std::vector &strides, } indices[indices.size() - 1] = offset; } + +void ShapeUtils::increment_index(std::vector &index, + const std::vector &dims, + size_t axisToIncrementOn) { + for (size_t i = axisToIncrementOn - 1; i >= 0; --i) { + index[i]++; + if (index[i] < dims[i]) { + break; + } + index[i] = 0; + } +} diff --git a/src/wasm-ops/utils/shape_utils.h b/src/wasm-ops/utils/shape_utils.h index 2d138136..9de9dd4b 100644 --- a/src/wasm-ops/utils/shape_utils.h +++ b/src/wasm-ops/utils/shape_utils.h @@ -18,4 +18,7 @@ std::vector offset_to_indices(const std::vector &strides, // Fills in values in the indices vector. Assumes it is of the required size. void offset_to_indices(const std::vector &strides, size_t offset, std::vector &indices); +void increment_index(std::vector &index, + const std::vector &dims, + size_t axisToIncrementOn); }; // namespace ShapeUtils From 22b4d129e2cf9954cbba83387a4cd8ade59da8f4 Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Mon, 7 Sep 2020 16:14:33 +0200 Subject: [PATCH 05/10] Format documents --- docs/operators.md | 2 +- lib/backends/wasm/op-resolve-rules.ts | 4 ++-- test/data/ops/einsum.jsonc | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/operators.md b/docs/operators.md index 47b30cee..66ecab43 100644 --- a/docs/operators.md +++ b/docs/operators.md @@ -43,7 +43,7 @@ _This file is automatically generated from the def files via [this script](/tool | [Div](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Div) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | | [Dropout](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Dropout) | [7-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-7), [10-11](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-10), [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-12) | | [7-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-7), [10-11](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-10), [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-12) | | [DynamicQuantizeLinear](https://github.com/onnx/onnx/blob/master/docs/Operators.md#DynamicQuantizeLinear) | | | | -| [Einsum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Einsum) | | | | +| [Einsum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Einsum) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | | | [Elu](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Elu) | [6+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Elu-6) | | [6+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Elu-6) | | [Equal](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Equal) | | | [7-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Equal-7), [11+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Equal-11) | | [Erf](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Erf) | | | | diff --git a/lib/backends/wasm/op-resolve-rules.ts b/lib/backends/wasm/op-resolve-rules.ts index cee2e91d..8da862b7 100644 --- a/lib/backends/wasm/op-resolve-rules.ts +++ b/lib/backends/wasm/op-resolve-rules.ts @@ -1,4 +1,3 @@ -import { WasmEinsum } from './ops/einsum'; // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. @@ -8,6 +7,7 @@ import {WasmBatchNormalization} from './ops/batch-normalization'; import {WasmBinaryOp} from './ops/binary-op'; import {WasmClip} from './ops/clip'; import {WasmConv} from './ops/conv'; +import {WasmEinsum} from './ops/einsum'; import {WasmGemm} from './ops/gemm'; import {WasmInstanceNormalization} from './ops/instance-normalization'; import {WasmMatMul} from './ops/matmul'; @@ -37,5 +37,5 @@ export const WASM_OP_RESOLVE_RULES: ReadonlyArray = [ ['Sub', '', '7+', () => new WasmBinaryOp(['float32'], 'Sub')], ['Sum', '', '6+', () => new WasmSum()], // TODO: support multidirectional broadcast for Sum-8 ['Xor', '', '7+', () => new WasmBinaryOp(['bool'], 'Xor')], - ['Einsum', '', '12+', () => new WasmEinsum()], // TODO: support multidirectional broadcast for Sum-8 + ['Einsum', '', '12+', () => new WasmEinsum()], ]; diff --git a/test/data/ops/einsum.jsonc b/test/data/ops/einsum.jsonc index c4fb8ebb..d51d9c01 100644 --- a/test/data/ops/einsum.jsonc +++ b/test/data/ops/einsum.jsonc @@ -100,7 +100,7 @@ ], "cases": [ { - "name": "BMM", + "name": "Transpose", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], @@ -130,7 +130,7 @@ "attributes": [{ "name": "equation", "data": "i,i", "type": "string" }], "cases": [ { - "name": "BMM", + "name": "Inner product", "inputs": [ { "data": [1, 2, 3, 4, 5], @@ -165,7 +165,7 @@ "attributes": [{ "name": "equation", "data": "bii->bi", "type": "string" }], "cases": [ { - "name": "BMM", + "name": "Diagonal", "inputs": [ { "data": [ @@ -214,7 +214,7 @@ "attributes": [{ "name": "equation", "data": "ij->i", "type": "string" }], "cases": [ { - "name": "BMM", + "name": "Sum", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], From 1e757fd79403693fa255a638e201826f44c244a2 Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Mon, 7 Sep 2020 19:18:35 +0200 Subject: [PATCH 06/10] Generalize einsum to variadic inputs --- lib/backends/cpu/ops/einsum.ts | 37 +++++++++++-------------- lib/backends/wasm/ops/einsum.ts | 34 +++++++++++++---------- lib/ops/einsum.ts | 48 ++++++++++++++------------------- 3 files changed, 55 insertions(+), 64 deletions(-) diff --git a/lib/backends/cpu/ops/einsum.ts b/lib/backends/cpu/ops/einsum.ts index cb0ece54..13d33572 100644 --- a/lib/backends/cpu/ops/einsum.ts +++ b/lib/backends/cpu/ops/einsum.ts @@ -26,30 +26,27 @@ export class CpuEinsum extends Einsum { } const outputIndices: number[] = []; - const input1Indices: number[] = []; - const input2Indices: number[] = []; + const inputIndices: number[][] = []; for (const outputName of this.outputNames) { outputIndices.push(nameToId[outputName]); } - for (const inputName of this.input1Names) { - input1Indices.push(nameToId[inputName]); - } - if (this.input2) { - for (const inputName of this.input2Names) { - input2Indices.push(nameToId[inputName]); + for (let i = 0; i < this.inputs.length; i++) { + const indices = []; + for (const inputName of this.inputNames[i]) { + indices.push(nameToId[inputName]); } + inputIndices.push(indices); } - const result = - einsum(outputShape, inputs, sizes, outputIndices, input1Indices, this.input2 ? input2Indices : undefined); + const result = einsum(outputShape, inputs, sizes, outputIndices, inputIndices); return [result]; } } export function einsum( - outputShape: number[], inputs: Tensor[], sizes: number[], outputIndices: number[], input1Indices: number[], - input2Indices?: number[]): Tensor { + outputShape: number[], inputs: Tensor[], sizes: number[], outputIndices: number[], + inputIndices: number[][]): Tensor { const result = new Tensor(outputShape, inputs[0].type); const totalSize = ShapeUtil.size(sizes); let i = 0; @@ -61,17 +58,13 @@ export function einsum( outputIx.push(index[outputIndex]); } - const input1Ix: number[] = []; - for (const input1Index of input1Indices) { - input1Ix.push(index[input1Index]); - } - let value = inputs[0].get(input1Ix) as number; - if (input2Indices) { - const input2Ix: number[] = []; - for (const input2Index of input2Indices) { - input2Ix.push(index[input2Index]); + let value = 1; + for (let i = 0; i < inputIndices.length; i++) { + const inputIx: number[] = []; + for (const inputIndex of inputIndices[i]) { + inputIx.push(index[inputIndex]); } - value *= inputs[1].get(input2Ix) as number; + value *= inputs[i].get(inputIx) as number; } result.set(outputIx, result.get(outputIx) as number + value); diff --git a/lib/backends/wasm/ops/einsum.ts b/lib/backends/wasm/ops/einsum.ts index 16bf7e81..f728685d 100644 --- a/lib/backends/wasm/ops/einsum.ts +++ b/lib/backends/wasm/ops/einsum.ts @@ -1,12 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +import {Attribute} from '../../../attribute'; import {Einsum} from '../../../ops/einsum'; import {Tensor} from '../../../tensor'; import {WasmBinding} from '../../../wasm-binding'; import {WasmInferenceHandler} from '../inference-handler'; export class WasmEinsum extends Einsum { + initialize(attributes: Attribute): void { + super.initialize(attributes); + if (this.inputs.length > 2) { + throw new Error('Wasm implementation of Einsum currently supports at most 2 inputs'); + } + } + run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] { const dimensionSizeMap: {[name: string]: number} = {}; this.matchInputs(inputs, dimensionSizeMap); @@ -25,18 +33,16 @@ export class WasmEinsum extends Einsum { } const outputIndices: number[] = []; - const input1Indices: number[] = []; - const input2Indices: number[] = []; + const inputIndices: number[][] = []; for (const outputName of this.outputNames) { outputIndices.push(nameToId[outputName]); } - for (const inputName of this.input1Names) { - input1Indices.push(nameToId[inputName]); - } - if (this.input2) { - for (const inputName of this.input2Names) { - input2Indices.push(nameToId[inputName]); + for (let i = 0; i < this.inputs.length; i++) { + const indices = []; + for (const inputName of this.inputNames[i]) { + indices.push(nameToId[inputName]); } + inputIndices.push(indices); } const y = new Tensor(outputShape, inputs[0].type); @@ -51,10 +57,10 @@ export class WasmEinsum extends Einsum { [sizes.length, 'int32'], [outputIndices, 'int32ptr'], [outputIndices.length, 'int32'], - [input1Indices, 'int32ptr'], - [input1Indices.length, 'int32'], - [input1Indices, 'int32ptr'], - [input2Indices.length, 'int32'], + [inputIndices[0], 'int32ptr'], + [inputIndices[0].length, 'int32'], + [inputIndices[1], 'int32ptr'], + [inputIndices[2].length, 'int32'], ); } else { WasmBinding.getInstance().ccall( @@ -65,8 +71,8 @@ export class WasmEinsum extends Einsum { [sizes.length, 'int32'], [outputIndices, 'int32ptr'], [outputIndices.length, 'int32'], - [input1Indices, 'int32ptr'], - [input1Indices.length, 'int32'], + [inputIndices[0], 'int32ptr'], + [inputIndices[1].length, 'int32'], ); } diff --git a/lib/ops/einsum.ts b/lib/ops/einsum.ts index 4d93dc4b..b1d4dbe9 100644 --- a/lib/ops/einsum.ts +++ b/lib/ops/einsum.ts @@ -21,15 +21,13 @@ export abstract class Einsum implements Operator { } const lhsSplit = this.lhs.split(','); - this.input1 = lhsSplit[0].trim(); - if (lhsSplit.length === 2) { - this.input2 = lhsSplit[1].trim(); - } + this.inputs = lhsSplit.map(v => v.trim()); - this.parseEquationPart(this.input1, this.input1Names); - if (this.input2) { - this.parseEquationPart(this.input2, this.input2Names); + for (let i = 0; i < this.inputs.length; i++) { + this.inputNames.push([]); + this.parseEquationPart(this.inputs[i], this.inputNames[i]); } + if (this.rhs) { this.parseEquationPart(this.rhs, this.outputNames); } @@ -48,9 +46,8 @@ export abstract class Einsum implements Operator { } protected matchInputs(inputs: Tensor[], dimensionSizeMap: {[name: string]: number}) { - this.matchDimensions(this.input1Names, inputs[0].dims, dimensionSizeMap); - if (this.input2) { - this.matchDimensions(this.input2Names, inputs[1].dims, dimensionSizeMap); + for (let i = 0; i < inputs.length; i++) { + this.matchDimensions(this.inputNames[i], inputs[i].dims, dimensionSizeMap); } } @@ -65,18 +62,14 @@ export abstract class Einsum implements Operator { checkInputs(inputs: Tensor[]): boolean { const dimensionMap: {[id: string]: number} = {}; - if (inputs.length < 1 || inputs.length > 2) { - return false; - } - - if (!this.matchDimensions(this.input1Names, inputs[0].dims, dimensionMap)) { + if (inputs.length !== this.inputs.length) { return false; } - if (this.input2 && inputs.length < 2) { - return false; - } else if (this.input2 && !this.matchDimensions(this.input2Names, inputs[1].dims, dimensionMap)) { - return false; + for (let i = 0; i < inputs.length; i++) { + if (!this.matchDimensions(this.inputNames[i], inputs[i].dims, dimensionMap)) { + return false; + } } return this.checkInputTypes(inputs); @@ -98,14 +91,16 @@ export abstract class Einsum implements Operator { protected checkInputTypes(inputs: Tensor[]): boolean { const allowedTypes = ['float32', 'float64', 'int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32']; - if (allowedTypes.indexOf(inputs[0].type) === -1 || - (inputs.length > 1 && allowedTypes.indexOf(inputs[1].type) === -1)) { + + if (inputs.find((v) => allowedTypes.indexOf(v.type) === -1) !== undefined) { return false; } - if (inputs.length > 1 && inputs[0].type !== inputs[1].type) { + const types = inputs.map(v => v.type) + if (types.find(v => v !== types[0]) !== undefined) { return false; } + return true; } @@ -113,13 +108,10 @@ export abstract class Einsum implements Operator { protected lhs: string; protected rhs?: string; - protected input1: string; - protected input2?: string; + protected inputs: string[] = []; - // Maps from input 1 axis to general axis id - protected input1Names: string[] = []; - // Maps from input 2 axis to general axis id - protected input2Names: string[] = []; + // The i-th string[] Maps from input axis i to general axis id + protected inputNames: string[][] = []; // Maps from output axis to general axis id protected outputNames: string[] = []; From 144c10691832b6f3c91a54b282c73760f02fdc44 Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Mon, 7 Sep 2020 20:00:23 +0200 Subject: [PATCH 07/10] Add webgl backend for einsum --- lib/backends/cpu/ops/einsum.ts | 29 +------ lib/backends/wasm/ops/einsum.ts | 29 +------ lib/backends/webgl/op-resolve-rules.ts | 2 + lib/backends/webgl/ops/einsum.ts | 110 +++++++++++++++++++++++++ lib/ops/einsum.ts | 33 ++++++++ 5 files changed, 147 insertions(+), 56 deletions(-) create mode 100644 lib/backends/webgl/ops/einsum.ts diff --git a/lib/backends/cpu/ops/einsum.ts b/lib/backends/cpu/ops/einsum.ts index 13d33572..6e3378d4 100644 --- a/lib/backends/cpu/ops/einsum.ts +++ b/lib/backends/cpu/ops/einsum.ts @@ -9,34 +9,7 @@ import {ShapeUtil} from './../../../util'; export class CpuEinsum extends Einsum { run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] { - const dimensionSizeMap: {[name: string]: number} = {}; - this.matchInputs(inputs, dimensionSizeMap); - const outputShape = this.calculateOutputSize(dimensionSizeMap); - - let i = 0; - const sizes = []; - const nameToId: {[name: string]: number} = {}; - const idToName: {[id: number]: string} = {}; - - for (const name in dimensionSizeMap) { - sizes.push(dimensionSizeMap[name]); - nameToId[name] = i; - idToName[i] = name; - i++; - } - - const outputIndices: number[] = []; - const inputIndices: number[][] = []; - for (const outputName of this.outputNames) { - outputIndices.push(nameToId[outputName]); - } - for (let i = 0; i < this.inputs.length; i++) { - const indices = []; - for (const inputName of this.inputNames[i]) { - indices.push(nameToId[inputName]); - } - inputIndices.push(indices); - } + const {outputShape, sizes, outputIndices, inputIndices} = this.prepareRun(inputs); const result = einsum(outputShape, inputs, sizes, outputIndices, inputIndices); diff --git a/lib/backends/wasm/ops/einsum.ts b/lib/backends/wasm/ops/einsum.ts index f728685d..49c6c194 100644 --- a/lib/backends/wasm/ops/einsum.ts +++ b/lib/backends/wasm/ops/einsum.ts @@ -16,34 +16,7 @@ export class WasmEinsum extends Einsum { } run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] { - const dimensionSizeMap: {[name: string]: number} = {}; - this.matchInputs(inputs, dimensionSizeMap); - const outputShape = this.calculateOutputSize(dimensionSizeMap); - - let i = 0; - const sizes = []; - const nameToId: {[name: string]: number} = {}; - const idToName: {[id: number]: string} = {}; - - for (const name in dimensionSizeMap) { - sizes.push(dimensionSizeMap[name]); - nameToId[name] = i; - idToName[i] = name; - i++; - } - - const outputIndices: number[] = []; - const inputIndices: number[][] = []; - for (const outputName of this.outputNames) { - outputIndices.push(nameToId[outputName]); - } - for (let i = 0; i < this.inputs.length; i++) { - const indices = []; - for (const inputName of this.inputNames[i]) { - indices.push(nameToId[inputName]); - } - inputIndices.push(indices); - } + const {outputShape, sizes, outputIndices, inputIndices} = this.prepareRun(inputs); const y = new Tensor(outputShape, inputs[0].type); diff --git a/lib/backends/webgl/op-resolve-rules.ts b/lib/backends/webgl/op-resolve-rules.ts index 7b9d56fe..ec208f71 100644 --- a/lib/backends/webgl/op-resolve-rules.ts +++ b/lib/backends/webgl/op-resolve-rules.ts @@ -10,6 +10,7 @@ import {WebGLClip} from './ops/clip'; import {WebGLConcat} from './ops/concat'; import {WebGLConv} from './ops/conv'; import {WebGLDropout} from './ops/dropout'; +import {WebGLEinsum} from './ops/einsum'; import {WebGLElu} from './ops/elu'; import {WebGLFlatten} from './ops/flatten'; import {WebGLGather} from './ops/gather'; @@ -105,4 +106,5 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray = [ ['Upsample', '', '7-8', () => new WebGLUpsample()], ['Unsqueeze', '', '1+', () => new WebGLUnsqueeze()], ['Xor', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslXor())], + ['Einsum', '', '1+', () => new WebGLEinsum()], ]; diff --git a/lib/backends/webgl/ops/einsum.ts b/lib/backends/webgl/ops/einsum.ts new file mode 100644 index 00000000..233bd451 --- /dev/null +++ b/lib/backends/webgl/ops/einsum.ts @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Einsum} from '../../../ops/einsum'; +import {Tensor} from '../../../tensor'; +import {WebGLInferenceHandler} from '../inference-handler'; +import {ProgramInfo, RunData, WebGLOperator} from '../types'; + +import {ShapeUtil} from './../../../util'; + +const samplerNames = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'.split(''); + +export class WebGLEinsum extends Einsum implements WebGLOperator { + run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { + return inferenceHandler.run(this, inputs); + } + + createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { + const {outputShape, sizes, outputIndices, inputIndices} = this.prepareRun(inputs); + + const sumDims = []; + const sumDimSizes = []; + for (let i = 0; i < sizes.length; i++) { + if (outputIndices.indexOf(i) === -1) { + sumDims.push(i); + sumDimSizes.push(sizes[i]); + } + } + const sumSize = ShapeUtil.size(sumDimSizes); + + let rank = outputShape.length; + // Webgl doesnt like 0 length arrays + if (rank === 0) { + rank = 1; + } + + const initIndex1 = outputIndices.map((x, i) => `index[${x}] = indices[${i}];`).join('\n'); + const initIndex2 = sumDims.map(x => `index[${x}] = 0;`).join('\n'); + + const findInputValues = inputs.map((_, i) => this.buildFindInputValueScript(i, inputIndices[i])).join('\n'); + + const incrementIndex = this.buildIncrementIndexScript(sumDims, sumDimSizes); + + const shaderSource = ` + float process(int indices[${rank}]) { + float value = 0.0; + + int index[${sizes.length}]; + ${initIndex1} + ${initIndex2} + + int i = 0; + while(i < ${sumSize}) { + float add = 1.0; + + ${findInputValues} + + value += add; + + ${incrementIndex} + i++; + } + + return value; + }`; + const inputLayouts = inputs.map(t => inferenceHandler.getOrCreateTextureLayout(t)); + return { + inputLayouts, + outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape), + samplers: samplerNames.slice(0, inputs.length), + shaderSource, + }; + } + + buildFindInputValueScript(i: number, indices: number[]): string { + const initInputIndex = indices.map((ix, indiceNum) => `input${i}Index[${indiceNum}] = index[${ix}];`).join('\n'); + + const script = `int input${i}Index[${indices.length}]; + ${initInputIndex} + add *= _${samplerNames[i]}(input${i}Index);`; + + return script; + } + + buildIncrementIndexScript(sumDims: number[], sumDimSizes: number[]): string { + let script = ''; + for (let i = 0; i < sumDims.length; i++) { + script += ` + index[${sumDims[i]}] += 1; + if (index[${sumDims[i]}] >= ${sumDimSizes[i]}) { + index[${sumDims[i]}] = 0; + `; + } + for (let i = 0; i < sumDims.length; i++) { + script += '}\n'; + } + + return script; + } + + createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData { + const inputTDs = inputs.map((v, i) => inferenceHandler.getOrCreateTextureData(v, programInfo.inputLayouts[i])); + return { + inputTextureDatas: inputTDs, + outputTextureData: + inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type), + uniformData: {} + }; + } +} diff --git a/lib/ops/einsum.ts b/lib/ops/einsum.ts index b1d4dbe9..3da2d7c9 100644 --- a/lib/ops/einsum.ts +++ b/lib/ops/einsum.ts @@ -9,6 +9,39 @@ import {Tensor} from '../tensor'; export abstract class Einsum implements Operator { abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise; + prepareRun(inputs: Tensor[]) { + const dimensionSizeMap: {[name: string]: number} = {}; + this.matchInputs(inputs, dimensionSizeMap); + const outputShape = this.calculateOutputSize(dimensionSizeMap); + + let i = 0; + const sizes = []; + const nameToId: {[name: string]: number} = {}; + const idToName: {[id: number]: string} = {}; + + for (const name in dimensionSizeMap) { + sizes.push(dimensionSizeMap[name]); + nameToId[name] = i; + idToName[i] = name; + i++; + } + + const outputIndices: number[] = []; + const inputIndices: number[][] = []; + for (const outputName of this.outputNames) { + outputIndices.push(nameToId[outputName]); + } + for (let i = 0; i < this.inputs.length; i++) { + const indices = []; + for (const inputName of this.inputNames[i]) { + indices.push(nameToId[inputName]); + } + inputIndices.push(indices); + } + + return {outputShape, sizes, outputIndices, inputIndices}; + } + initialize(attributes: Attribute): void { this.equation = attributes.getString('equation'); const split = this.equation.split('->'); From d3be4bb9bda645dd042e5579081d533339a3d1b7 Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Mon, 7 Sep 2020 20:01:31 +0200 Subject: [PATCH 08/10] Update operator documentation --- docs/operators.md | 2 +- lib/ops/einsum.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/operators.md b/docs/operators.md index 66ecab43..5d848ae3 100644 --- a/docs/operators.md +++ b/docs/operators.md @@ -43,7 +43,7 @@ _This file is automatically generated from the def files via [this script](/tool | [Div](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Div) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | | [Dropout](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Dropout) | [7-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-7), [10-11](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-10), [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-12) | | [7-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-7), [10-11](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-10), [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-12) | | [DynamicQuantizeLinear](https://github.com/onnx/onnx/blob/master/docs/Operators.md#DynamicQuantizeLinear) | | | | -| [Einsum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Einsum) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | | +| [Einsum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Einsum) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | | [Elu](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Elu) | [6+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Elu-6) | | [6+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Elu-6) | | [Equal](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Equal) | | | [7-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Equal-7), [11+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Equal-11) | | [Erf](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Erf) | | | | diff --git a/lib/ops/einsum.ts b/lib/ops/einsum.ts index 3da2d7c9..f9d352b9 100644 --- a/lib/ops/einsum.ts +++ b/lib/ops/einsum.ts @@ -129,7 +129,7 @@ export abstract class Einsum implements Operator { return false; } - const types = inputs.map(v => v.type) + const types = inputs.map(v => v.type); if (types.find(v => v !== types[0]) !== undefined) { return false; } From 78cddfc76fe341595575e2638dc72dd5d134bcda Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Mon, 7 Sep 2020 20:41:57 +0200 Subject: [PATCH 09/10] Add einsum.jsonc to test suite --- test/test-suite-whitelist.jsonc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/test-suite-whitelist.jsonc b/test/test-suite-whitelist.jsonc index 64fc0cef..85554f27 100644 --- a/test/test-suite-whitelist.jsonc +++ b/test/test-suite-whitelist.jsonc @@ -265,7 +265,8 @@ "softmax.jsonc", "tan.jsonc", "transpose.jsonc", - "xor.jsonc" + "xor.jsonc", + "einsum.jsonc" ] }, "webgl": { @@ -527,7 +528,8 @@ "sub.jsonc", "tan.jsonc", "transpose.jsonc", - "xor.jsonc" + "xor.jsonc", + "einsum.jsonc" ] }, "wasm": { @@ -639,7 +641,8 @@ "and.jsonc", "or.jsonc", "xor.jsonc", - "matmul.jsonc" + "matmul.jsonc", + "einsum.jsonc" ] } } From 2f04d18e6e3a88e13ba6a928d8ce8a403b332fb8 Mon Sep 17 00:00:00 2001 From: Frithjof Winkelmann Date: Tue, 8 Sep 2020 09:52:03 +0200 Subject: [PATCH 10/10] Add tests for einsum --- test/test-suite-whitelist.jsonc | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/test/test-suite-whitelist.jsonc b/test/test-suite-whitelist.jsonc index 018d18ef..93df3718 100644 --- a/test/test-suite-whitelist.jsonc +++ b/test/test-suite-whitelist.jsonc @@ -225,7 +225,11 @@ "test_xor_bcast4v4d", "test_xor2d", "test_xor3d", - "test_xor4d" + "test_xor4d", + "v12/test_einsum_batch_matmul", + "v12/test_einsum_inner_prod", + "v12/test_einsum_sum", + "v12/test_einsum_transpose" ], "ops": [ "abs.jsonc", @@ -486,7 +490,11 @@ "test_xor_bcast4v4d", "test_xor2d", "test_xor3d", - "test_xor4d" + "test_xor4d", + "v12/test_einsum_batch_matmul", + "v12/test_einsum_inner_prod", + "v12/test_einsum_sum", + "v12/test_einsum_transpose" ], "ops": [ "abs.jsonc", @@ -626,7 +634,11 @@ "test_globalmaxpool_precomputed", "test_globalmaxpool", "test_instancenorm_epsilon", - "test_instancenorm_example" + "test_instancenorm_example", + "v12/test_einsum_batch_matmul", + "v12/test_einsum_inner_prod", + "v12/test_einsum_sum", + "v12/test_einsum_transpose" ], "ops": [ // Check in op tests that have native Wasm implementations