diff --git a/lib/backends/cpu/ops-resolve.ts b/lib/backends/cpu/ops-resolve.ts index 8b50e1d5..71793ffb 100644 --- a/lib/backends/cpu/ops-resolve.ts +++ b/lib/backends/cpu/ops-resolve.ts @@ -10,6 +10,7 @@ import {CpuBinaryOp} from './ops/binary-op'; import {CpuConcat} from './ops/concat'; import {CpuConv} from './ops/conv'; import {CpuDropout} from './ops/dropout'; +import {CpuFlatten} from './ops/flatten'; import {CpuGather} from './ops/gather'; import {CpuGemm} from './ops/gemm'; import {CpuImageScaler} from './ops/image-scaler'; @@ -98,6 +99,8 @@ function createOperator(node: Graph.Node, domain: string, version: number): Oper return new CpuConv(); case 'Dropout': return new CpuDropout(); + case 'Flatten': + return new CpuFlatten(); case 'Gemm': return new CpuGemm(); case 'ImageScaler': diff --git a/lib/backends/cpu/ops/flatten.ts b/lib/backends/cpu/ops/flatten.ts new file mode 100644 index 00000000..80f45eb9 --- /dev/null +++ b/lib/backends/cpu/ops/flatten.ts @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Flatten} from '../../../ops/flatten'; +import {Tensor} from '../../../tensor'; +import {ShapeUtil} from '../../../util'; +import {CpuInferenceHandler} from '../inference-handler'; + +export class CpuFlatten extends Flatten { + run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] { + const output = flatten(inputs[0], this.axis); + return [output]; + } +} + +export function flatten(x: Tensor, axis: number): Tensor { + const outputDims = ShapeUtil.flattenShape(x.dims, axis); + const output = new Tensor(outputDims, x.type); + + const X = x.numberData; + const Y = output.numberData; + + Y.set(X); + + return output; +} diff --git a/lib/backends/webgl/ops/flatten.ts b/lib/backends/webgl/ops/flatten.ts new file mode 100644 index 00000000..59777fe8 --- /dev/null +++ b/lib/backends/webgl/ops/flatten.ts @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Flatten} from '../../../ops/flatten'; +import {Tensor} from '../../../tensor'; +import {ShapeUtil} from '../../../util'; +import {WebGLInferenceHandler} from '../inference-handler'; + +import {reshape} from './reshape'; + +export class WebGLFlatten extends Flatten { + run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { + const outputDims = ShapeUtil.flattenShape(inputs[0].dims, this.axis); + + return [reshape(inferenceHandler, inputs[0], outputDims)]; + } +} diff --git a/lib/backends/webgl/session-handler.ts b/lib/backends/webgl/session-handler.ts index 44081fcb..b4097233 100644 --- a/lib/backends/webgl/session-handler.ts +++ b/lib/backends/webgl/session-handler.ts @@ -15,6 +15,7 @@ import * as binaryOps from './ops/binary-op'; import {WebGLConcat} from './ops/concat'; import {WebGLConv} from './ops/conv'; import {WebGLDropout} from './ops/dropout'; +import {WebGLFlatten} from './ops/flatten'; import {WebGLGather} from './ops/gather'; import {WebGLGemm} from './ops/gemm'; import {WebGLImageScaler} from './ops/image-scaler'; @@ -123,6 +124,8 @@ export class WebGLSessionHandler implements SessionHandler { return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslEqual(), undefined, 'bool'); case 'Exp': return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslExp()); + case 'Flatten': + return new WebGLFlatten(); case 'Floor': return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslFloor()); case 'Gather': diff --git a/lib/ops/flatten.ts b/lib/ops/flatten.ts new file mode 100644 index 00000000..2a7e8cff --- /dev/null +++ b/lib/ops/flatten.ts @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Attribute} from '../attribute'; +import {InferenceHandler} from '../backend'; +import {Operator} from '../operators'; +import {Tensor} from '../tensor'; + +export abstract class Flatten implements Operator { + abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise; + + initialize(attributes: Attribute): void { + this.axis = attributes.getInt('axis', 1); // default axis is 1 + } + + checkInputs(inputs: Tensor[]): boolean { + if (!inputs || inputs.length !== 1) { + return false; + } + + if (inputs[0].dims.length === 0) { + return false; // scalar tensor is not supported + } + + if (this.axis < 0 || this.axis > inputs[0].dims.length) { + return false; + } + + return this.checkInputTypes(inputs); + } + + protected checkInputTypes(inputs: Tensor[]): boolean { + // TODO: Support string type + if (inputs[0].type === 'string') { + return false; + } + + return true; + } + + protected axis: number; +} diff --git a/lib/util.ts b/lib/util.ts index 07dde8a4..75208c6c 100644 --- a/lib/util.ts +++ b/lib/util.ts @@ -667,6 +667,19 @@ export class ShapeUtil { return size; } + /** + * Determines the shape of output tensor y = flatten(x, axis) + * @param dims - shape of input tensor + * @param axis - flatten axis + */ + static flattenShape(dims: ReadonlyArray, axis: number): ReadonlyArray { + const total = dims.reduce((x, y) => x * y, 1); + const right = dims.slice(axis).reduce((x, y) => x * y, 1); + const outputDims = [total / right, right]; + + return outputDims; + } + /** * Determines the shape of output tensor y = squeeze(x, axes) * @param dims - shape of input tensor diff --git a/test/unittest-whitelist.jsonc b/test/unittest-whitelist.jsonc index 19d6884f..25dec7ec 100644 --- a/test/unittest-whitelist.jsonc +++ b/test/unittest-whitelist.jsonc @@ -61,6 +61,11 @@ "test_div", "test_dropout_default", "test_dropout_random", + "test_flatten_axis0", + "test_flatten_axis1", + "test_flatten_axis2", + "test_flatten_axis3", + "test_flatten_default_axis", "test_gather_0", "test_gather_1", "test_gemm_broadcast", @@ -286,6 +291,11 @@ "test_div", "test_dropout_default", "test_dropout_random", + "test_flatten_axis0", + "test_flatten_axis1", + "test_flatten_axis2", + "test_flatten_axis3", + "test_flatten_default_axis", "test_gather_0", "test_gather_1", "test_gemm_nobroadcast",