diff --git a/lib/backends/cpu/ops-resolve.ts b/lib/backends/cpu/ops-resolve.ts
index d8f9e6b0..9310d144 100644
--- a/lib/backends/cpu/ops-resolve.ts
+++ b/lib/backends/cpu/ops-resolve.ts
@@ -49,6 +49,8 @@ function createOperator(node: Graph.Node, domain: string, version: number): Oper
       return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.ceil);
     case 'Cos':
       return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.cos);
+    case 'Clip':
+      return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.clip);
     case 'Sin':
       return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.sin);
     case 'Tan':
diff --git a/lib/backends/cpu/ops/unary-op.ts b/lib/backends/cpu/ops/unary-op.ts
index ae58a5cc..dc9bbee7 100644
--- a/lib/backends/cpu/ops/unary-op.ts
+++ b/lib/backends/cpu/ops/unary-op.ts
@@ -67,6 +67,15 @@ export function cos(input: Tensor.NumberType, output: Tensor.NumberType, attribu
   }
 }
 
+export function clip(input: Tensor.NumberType, output: Tensor.NumberType, attributes: Attribute) {
+  const min = attributes.getFloat('min', -3.4028234663852886e+38);
+  const max = attributes.getFloat('max', 3.4028234663852886e+38);
+  for (let i = 0; i < input.length; i++) {
+    const value = input[i];
+    output[i] = (value < min) ? min : (value > max) ? max : value;
+  }
+}
+
 export function sin(input: Tensor.NumberType, output: Tensor.NumberType, attributes: Attribute) {
   for (let i = 0; i < input.length; i++) {
     output[i] = Math.sin(input[i]);
diff --git a/lib/backends/wasm/ops/clip.ts b/lib/backends/wasm/ops/clip.ts
new file mode 100644
index 00000000..210360b3
--- /dev/null
+++ b/lib/backends/wasm/ops/clip.ts
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+import {Clip} from '../../../ops/clip';
+import {Tensor} from '../../../tensor';
+import {WasmBinding} from '../../../wasm-binding';
+import {WasmInferenceHandler} from '../inference-handler';
+
+export class WasmClip extends Clip {
+  run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
+    const result = new Tensor(inputs[0].dims, inputs[0].type);
+    const size = result.floatData.length;
+    if (inputs[0].type === 'float32') {
+      WasmBinding.getInstance().ccall(
+          '_clip_f32', [inputs[0].floatData, 'float32ptr'], [result.floatData, 'float32ptr', 'out'], [size, 'int32'],
+          [this.min, 'float32'], [this.max, 'float32']);
+    }
+    // Expand for different types supported for this specific kernel of Clip
+    else {
+      throw new Error(`Unsupported input type for Clip operator.`);
+    }
+    return [result];
+  }
+
+  // overriding the checkInputTypes() in the base class because the Wasm backend has special type limitations
+  checkInputTypes(inputs: Tensor[]): boolean {
+    // currently the Wasm backend only supports the 'float32' input type
+    if (inputs[0].type !== 'float32') {
+      return false;
+    }
+
+    return true;
+  }
+}
diff --git a/lib/backends/wasm/session-handler.ts b/lib/backends/wasm/session-handler.ts
index 0aaebe87..7eb10274 100644
--- a/lib/backends/wasm/session-handler.ts
+++ b/lib/backends/wasm/session-handler.ts
@@ -6,9 +6,11 @@ import {Graph} from '../../graph';
 import {Operator} from '../../operators';
 import {Session} from '../../session';
 import {resolve} from '../cpu/ops-resolve';
+
 import {WasmInferenceHandler} from './inference-handler';
 import {WasmBatchNormalization} from './ops/batch-normalization';
 import {WasmBinaryOp} from './ops/binary-op';
+import {WasmClip} from './ops/clip';
 import {WasmConv} from './ops/conv';
 import {WasmGemm} from './ops/gemm';
 import {WasmInstanceNormalization} from './ops/instance-normalization';
@@ -55,6 +57,8 @@ export class WasmSessionHandler implements SessionHandler {
       // Misc ops
       case 'Conv':
         return new WasmConv();
+      case 'Clip':
+        return new WasmClip();
       case 'BatchNormalization':
         return new WasmBatchNormalization();
       case 'Gemm':
diff --git a/lib/backends/webgl/ops/clip.ts b/lib/backends/webgl/ops/clip.ts
new file mode 100644
index 00000000..9fba979e
--- /dev/null
+++ b/lib/backends/webgl/ops/clip.ts
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+import {Clip} from '../../../ops/clip';
+import {Tensor} from '../../../tensor';
+import {WebGLInferenceHandler} from '../inference-handler';
+import {ProgramInfo} from '../program-info';
+import {RunData} from '../program-manager';
+import {WebGLOperator} from '../webgl-operator';
+import {WebGLOperatorHelper} from '../webgl-operator-utils';
+
+export class WebGLClip extends Clip implements WebGLOperator {
+  run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
+    return WebGLOperatorHelper.run(this, inferenceHandler, inputs);
+  }
+  createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
+    const outputShape = inputs[0].dims.slice();
+    const shaderSource = `
+      const float min = float(${this.min});
+      const float max = float(${this.max});
+      uniform sampler2D A;
+      void main() {
+        float v = texture2D(A, TexCoords).r;
+        gl_FragColor = vec4(clamp(v, min, max));
+      }
+      `;
+    return {
+      hasMain: true,
+      inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
+      outputLayout: handler.createBasicTextureLayout(outputShape),
+      shaderSource,
+    };
+  }
+  createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
+    const inputTDs = [handler.getOrCreate(inputs[0], programInfo.inputLayouts[0])];
+    return {
+      inputTextureDatas: inputTDs,
+      outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].dataType),
+      uniformData: {}
+    };
+  }
+}
diff --git a/lib/backends/webgl/session-handler.ts b/lib/backends/webgl/session-handler.ts
index 10d2e260..67cd79f3 100644
--- a/lib/backends/webgl/session-handler.ts
+++ b/lib/backends/webgl/session-handler.ts
@@ -12,6 +12,7 @@ import {SessionHandler} from './../../backend';
 import {WebGLInferenceHandler} from './inference-handler';
 import {WebGLBatchNormalization} from './ops/batch-normalization';
 import * as binaryOps from './ops/binary-op';
+import {WebGLClip} from './ops/clip';
 import {WebGLConcat} from './ops/concat';
 import {WebGLConv} from './ops/conv';
 import {WebGLDropout} from './ops/dropout';
@@ -111,6 +112,8 @@ export class WebGLSessionHandler implements SessionHandler {
         return new WebGLBatchNormalization();
       case 'Ceil':
         return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCeil());
+      case 'Clip':
+        return new WebGLClip();
       case 'Cos':
         return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCos());
       case 'Concat':
diff --git a/lib/ops/clip.ts b/lib/ops/clip.ts
new file mode 100644
index 00000000..5ec6213d
--- /dev/null
+++ b/lib/ops/clip.ts
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+import {Attribute} from '../attribute';
+import {InferenceHandler} from '../backend';
+import {Operator} from '../operators';
+import {Tensor} from '../tensor';
+
+export abstract class Clip implements Operator {
+  abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
+
+  initialize(attributes: Attribute): void {
+    this.min = attributes.getFloat('min', -3.4028234663852886e+38);
+    this.max = attributes.getFloat('max', 3.4028234663852886e+38);
+  }
+
+  checkInputs(inputs: Tensor[]): boolean {
+    if (!inputs || inputs.length !== 1) {
+      return false;
+    }
+
+    return this.checkInputTypes(inputs);
+  }
+
+  protected checkInputTypes(inputs: Tensor[]): boolean {
+    if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
+      return false;
+    }
+
+    return true;
+  }
+
+  protected min: number;
+  protected max: number;
+}
diff --git a/src/wasm-build-config.json b/src/wasm-build-config.json
index 50e516a4..95a99ee8 100644
--- a/src/wasm-build-config.json
+++ b/src/wasm-build-config.json
@@ -19,6 +19,7 @@
     "_gemm_f32",
     "_matmul_f32",
     "_batch_normalization_f32",
+    "_clip_f32",
     "_instance_normalization_f32",
     "_sum_f32",
     "_softmax_f32"
diff --git a/src/wasm-ops/clip.cpp b/src/wasm-ops/clip.cpp
new file mode 100644
index 00000000..9cd6aef4
--- /dev/null
+++ b/src/wasm-ops/clip.cpp
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "clip.h"
+
+// Wasm interop methods
+void clip_f32(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  const float *input = PARAM_FLOAT_PTR(data, dataIndex[1]);
+  float *output = PARAM_FLOAT_PTR(data, dataIndex[2]);
+  const int32_t length = PARAM_INT32(data, dataIndex[3]);
+  const float min = PARAM_FLOAT(data, dataIndex[4]);
+  const float max = PARAM_FLOAT(data, dataIndex[5]);
+  clip_imp(input, output, length, min, max);
+}
diff --git a/src/wasm-ops/clip.h b/src/wasm-ops/clip.h
new file mode 100644
index 00000000..a5ae5c8b
--- /dev/null
+++ b/src/wasm-ops/clip.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include "common.h"
+
+extern "C" {
+void clip_f32(void *);
+// Expand for other supported data types for `clip`
+}
+
+// Core implementation of the op
+template <typename T>
+void clip_imp(const T *input, T *output, const int32_t length, const float min,
+              const float max) {
+  for (size_t i = 0; i < length; ++i) {
+    const auto &val = input[i];
+    output[i] = (val < min) ? min : (val > max) ? max : val;
+  }
+}
diff --git a/test/unittest-whitelist.jsonc b/test/unittest-whitelist.jsonc
index b1b283a1..ffe6c6d0 100644
--- a/test/unittest-whitelist.jsonc
+++ b/test/unittest-whitelist.jsonc
@@ -44,6 +44,14 @@
     "test_basic_conv_without_padding",
     "test_batchnorm_epsilon",
     "test_batchnorm_example",
+    "test_clip_splitbounds",
+    "test_clip_outbounds",
+    "test_clip_inbounds",
+    "test_clip_example",
+    "test_clip_default_min",
+    "test_clip_default_max",
+    "test_clip_default_inbounds",
+    "test_clip",
     "test_concat_1d_axis_0",
     "test_concat_2d_axis_0",
     "test_concat_2d_axis_1",
@@ -277,6 +285,14 @@
     "test_basic_conv_without_padding",
     "test_batchnorm_epsilon",
     "test_batchnorm_example",
+    "test_clip_splitbounds",
+    "test_clip_outbounds",
+    "test_clip_inbounds",
+    "test_clip_example",
+    "test_clip_default_min",
+    "test_clip_default_max",
+    "test_clip_default_inbounds",
+    "test_clip",
     "test_concat_1d_axis_0",
     "test_concat_2d_axis_0",
     "test_concat_2d_axis_1",
@@ -534,6 +550,14 @@
     "test_basic_conv_without_padding",
     "test_batchnorm_epsilon",
     "test_batchnorm_example",
+    "test_clip_splitbounds",
+    "test_clip_outbounds",
+    "test_clip_inbounds",
+    "test_clip_example",
+    "test_clip_default_min",
+    "test_clip_default_max",
+    "test_clip_default_inbounds",
+    "test_clip",
     "test_conv_with_strides_and_asymmetric_padding",
     "test_conv_with_strides_no_padding",
     "test_conv_with_strides_padding",
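
For reference, below is a minimal standalone TypeScript sketch (not part of the diff) of the clamping semantics that the CPU, Wasm, and WebGL kernels above all share: `min`/`max` default to -/+3.4028234663852886e+38 (FLT_MAX) when the attributes are absent, so a Clip node without attributes passes float32 values through unchanged. The `clipArray` helper name and the example values are hypothetical, used only for illustration.

// Hypothetical standalone sketch of the Clip semantics implemented in this PR.
// Defaults mirror attributes.getFloat('min'|'max', -/+3.4028234663852886e+38) in the diff.
const FLT_MAX = 3.4028234663852886e+38;

function clipArray(input: Float32Array, min = -FLT_MAX, max = FLT_MAX): Float32Array {
  const output = new Float32Array(input.length);
  for (let i = 0; i < input.length; i++) {
    const value = input[i];
    // Same branch structure as the kernels above: clamp each element into [min, max].
    output[i] = (value < min) ? min : (value > max) ? max : value;
  }
  return output;
}

// Usage: out-of-range values saturate; with the default bounds the input is unchanged.
console.log(clipArray(new Float32Array([-2, -0.5, 0.5, 2]), -1, 1));  // Float32Array [-1, -0.5, 0.5, 1]
console.log(clipArray(new Float32Array([-2, 2])));                    // Float32Array [-2, 2]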