Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
operator: Support Clip op for cpu, wasm, and webgl backends (#107)
Browse files Browse the repository at this point in the history
* Initial commit

* Clip for WebGL

* Support Clip for WASM backend

* Formatting

* More changes

* Use GLSL clamp for core clip operation

* PR feedback

* More changes
  • Loading branch information
hariharans29 authored Mar 20, 2019
1 parent 988f891 commit 0432a35
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/backends/cpu/ops-resolve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ function createOperator(node: Graph.Node, domain: string, version: number): Oper
return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.ceil);
case 'Cos':
return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.cos);
case 'Clip':
return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.clip);
case 'Sin':
return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.sin);
case 'Tan':
Expand Down
9 changes: 9 additions & 0 deletions lib/backends/cpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ export function cos(input: Tensor.NumberType, output: Tensor.NumberType, attribu
}
}

export function clip(input: Tensor.NumberType, output: Tensor.NumberType, attributes: Attribute) {
const min = attributes.getFloat('min', -3.4028234663852886e+38);
const max = attributes.getFloat('max', 3.4028234663852886e+38);
for (let i = 0; i < input.length; i++) {
const value = input[i];
output[i] = (value < min) ? min : (value > max) ? max : value;
}
}

export function sin(input: Tensor.NumberType, output: Tensor.NumberType, attributes: Attribute) {
for (let i = 0; i < input.length; i++) {
output[i] = Math.sin(input[i]);
Expand Down
34 changes: 34 additions & 0 deletions lib/backends/wasm/ops/clip.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Clip} from '../../../ops/clip';
import {Tensor} from '../../../tensor';
import {WasmBinding} from '../../../wasm-binding';
import {WasmInferenceHandler} from '../inference-handler';

export class WasmClip extends Clip {
run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
const result = new Tensor(inputs[0].dims, inputs[0].type);
const size = result.floatData.length;
if (inputs[0].type === 'float32') {
WasmBinding.getInstance().ccall(
'_clip_f32', [inputs[0].floatData, 'float32ptr'], [result.floatData, 'float32ptr', 'out'], [size, 'int32'],
[this.min, 'float32'], [this.max, 'float32']);
}
// Expand for differnt types supported for this specific kernel of Clip
else {
throw new Error(`Unsupported input type for Clip operator.`);
}
return [result];
}

// overriding the checkInputTypes() in the base class because Wasm backend has special type limitations
checkInputTypes(inputs: Tensor[]): boolean {
// currently Wasm backend only supports 'float32' input type
if (inputs[0].type !== 'float32') {
return false;
}

return true;
}
}
4 changes: 4 additions & 0 deletions lib/backends/wasm/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import {Graph} from '../../graph';
import {Operator} from '../../operators';
import {Session} from '../../session';
import {resolve} from '../cpu/ops-resolve';

import {WasmInferenceHandler} from './inference-handler';
import {WasmBatchNormalization} from './ops/batch-normalization';
import {WasmBinaryOp} from './ops/binary-op';
import {WasmClip} from './ops/clip';
import {WasmConv} from './ops/conv';
import {WasmGemm} from './ops/gemm';
import {WasmInstanceNormalization} from './ops/instance-normalization';
Expand Down Expand Up @@ -55,6 +57,8 @@ export class WasmSessionHandler implements SessionHandler {
// Misc ops
case 'Conv':
return new WasmConv();
case 'Clip':
return new WasmClip();
case 'BatchNormalization':
return new WasmBatchNormalization();
case 'Gemm':
Expand Down
42 changes: 42 additions & 0 deletions lib/backends/webgl/ops/clip.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Clip} from '../../../ops/clip';
import {Tensor} from '../../../tensor';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo} from '../program-info';
import {RunData} from '../program-manager';
import {WebGLOperator} from '../webgl-operator';
import {WebGLOperatorHelper} from '../webgl-operator-utils';

export class WebGLClip extends Clip implements WebGLOperator {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
return WebGLOperatorHelper.run(this, inferenceHandler, inputs);
}
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const outputShape = inputs[0].dims.slice();
const shaderSource = `
const float min = float(${this.min});
const float max = float(${this.max});
uniform sampler2D A;
void main() {
float v = texture2D(A, TexCoords).r;
gl_FragColor = vec4(clamp(v, min, max));
}
`;
return {
hasMain: true,
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
outputLayout: handler.createBasicTextureLayout(outputShape),
shaderSource,
};
}
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
const inputTDs = [handler.getOrCreate(inputs[0], programInfo.inputLayouts[0])];
return {
inputTextureDatas: inputTDs,
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].dataType),
uniformData: {}
};
}
}
3 changes: 3 additions & 0 deletions lib/backends/webgl/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {SessionHandler} from './../../backend';
import {WebGLInferenceHandler} from './inference-handler';
import {WebGLBatchNormalization} from './ops/batch-normalization';
import * as binaryOps from './ops/binary-op';
import {WebGLClip} from './ops/clip';
import {WebGLConcat} from './ops/concat';
import {WebGLConv} from './ops/conv';
import {WebGLDropout} from './ops/dropout';
Expand Down Expand Up @@ -111,6 +112,8 @@ export class WebGLSessionHandler implements SessionHandler {
return new WebGLBatchNormalization();
case 'Ceil':
return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCeil());
case 'Clip':
return new WebGLClip();
case 'Cos':
return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCos());
case 'Concat':
Expand Down
35 changes: 35 additions & 0 deletions lib/ops/clip.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Attribute} from '../attribute';
import {InferenceHandler} from '../backend';
import {Operator} from '../operators';
import {Tensor} from '../tensor';

export abstract class Clip implements Operator {
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

initialize(attributes: Attribute): void {
this.min = attributes.getFloat('min', -3.4028234663852886e+38);
this.max = attributes.getFloat('max', 3.4028234663852886e+38);
}

checkInputs(inputs: Tensor[]): boolean {
if (!inputs || inputs.length !== 1) {
return false;
}

return this.checkInputTypes(inputs);
}

protected checkInputTypes(inputs: Tensor[]): boolean {
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
return false;
}

return true;
}

protected min: number;
protected max: number;
}
1 change: 1 addition & 0 deletions src/wasm-build-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"_gemm_f32",
"_matmul_f32",
"_batch_normalization_f32",
"_clip_f32",
"_instance_normalization_f32",
"_sum_f32",
"_softmax_f32"
Expand Down
15 changes: 15 additions & 0 deletions src/wasm-ops/clip.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "clip.h"

// Wasm interop methods
void clip_f32(void *data) {
uint32_t *dataIndex = static_cast<uint32_t *>(data);
const float *input = PARAM_FLOAT_PTR(data, dataIndex[1]);
float *output = PARAM_FLOAT_PTR(data, dataIndex[2]);
const int32_t length = PARAM_INT32(data, dataIndex[3]);
const float min = PARAM_FLOAT(data, dataIndex[4]);
const float max = PARAM_FLOAT(data, dataIndex[5]);
clip_imp<float>(input, output, length, min, max);
}
21 changes: 21 additions & 0 deletions src/wasm-ops/clip.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

#include "common.h"

extern "C" {
void clip_f32(void *);
// Expand for other supported data types for `clip`
}

// Core implementation of the op
template <typename T>
void clip_imp(const T *input, T *output, const int32_t length, const float min,
const float max) {
for (size_t i = 0; i < length; ++i) {
const auto &val = input[i];
output[i] = (val < min) ? min : (val > max) ? max : val;
}
}
24 changes: 24 additions & 0 deletions test/unittest-whitelist.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@
"test_basic_conv_without_padding",
"test_batchnorm_epsilon",
"test_batchnorm_example",
"test_clip_splitbounds",
"test_clip_outbounds",
"test_clip_inbounds",
"test_clip_example",
"test_clip_default_min",
"test_clip_default_max",
"test_clip_default_inbounds",
"test_clip",
"test_concat_1d_axis_0",
"test_concat_2d_axis_0",
"test_concat_2d_axis_1",
Expand Down Expand Up @@ -277,6 +285,14 @@
"test_basic_conv_without_padding",
"test_batchnorm_epsilon",
"test_batchnorm_example",
"test_clip_splitbounds",
"test_clip_outbounds",
"test_clip_inbounds",
"test_clip_example",
"test_clip_default_min",
"test_clip_default_max",
"test_clip_default_inbounds",
"test_clip",
"test_concat_1d_axis_0",
"test_concat_2d_axis_0",
"test_concat_2d_axis_1",
Expand Down Expand Up @@ -534,6 +550,14 @@
"test_basic_conv_without_padding",
"test_batchnorm_epsilon",
"test_batchnorm_example",
"test_clip_splitbounds",
"test_clip_outbounds",
"test_clip_inbounds",
"test_clip_example",
"test_clip_default_min",
"test_clip_default_max",
"test_clip_default_inbounds",
"test_clip",
"test_conv_with_strides_and_asymmetric_padding",
"test_conv_with_strides_no_padding",
"test_conv_with_strides_padding",
Expand Down

0 comments on commit 0432a35

Please sign in to comment.