This repository has been archived by the owner on Nov 16, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
operator: Support Clip op for cpu, wasm, and webgl backends (#107)
* Initial commit * Clip for WebGL * Support Clip for WASM backend * Formatting * More changes * Use GLSL clamp for core clip operation * PR feedback * More changes
- Loading branch information
1 parent
988f891
commit 0432a35
Showing
11 changed files
with
190 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT license. | ||
|
||
import {Clip} from '../../../ops/clip'; | ||
import {Tensor} from '../../../tensor'; | ||
import {WasmBinding} from '../../../wasm-binding'; | ||
import {WasmInferenceHandler} from '../inference-handler'; | ||
|
||
export class WasmClip extends Clip { | ||
run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] { | ||
const result = new Tensor(inputs[0].dims, inputs[0].type); | ||
const size = result.floatData.length; | ||
if (inputs[0].type === 'float32') { | ||
WasmBinding.getInstance().ccall( | ||
'_clip_f32', [inputs[0].floatData, 'float32ptr'], [result.floatData, 'float32ptr', 'out'], [size, 'int32'], | ||
[this.min, 'float32'], [this.max, 'float32']); | ||
} | ||
// Expand for differnt types supported for this specific kernel of Clip | ||
else { | ||
throw new Error(`Unsupported input type for Clip operator.`); | ||
} | ||
return [result]; | ||
} | ||
|
||
// overriding the checkInputTypes() in the base class because Wasm backend has special type limitations | ||
checkInputTypes(inputs: Tensor[]): boolean { | ||
// currently Wasm backend only supports 'float32' input type | ||
if (inputs[0].type !== 'float32') { | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT license. | ||
|
||
import {Clip} from '../../../ops/clip'; | ||
import {Tensor} from '../../../tensor'; | ||
import {WebGLInferenceHandler} from '../inference-handler'; | ||
import {ProgramInfo} from '../program-info'; | ||
import {RunData} from '../program-manager'; | ||
import {WebGLOperator} from '../webgl-operator'; | ||
import {WebGLOperatorHelper} from '../webgl-operator-utils'; | ||
|
||
export class WebGLClip extends Clip implements WebGLOperator { | ||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { | ||
return WebGLOperatorHelper.run(this, inferenceHandler, inputs); | ||
} | ||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { | ||
const outputShape = inputs[0].dims.slice(); | ||
const shaderSource = ` | ||
const float min = float(${this.min}); | ||
const float max = float(${this.max}); | ||
uniform sampler2D A; | ||
void main() { | ||
float v = texture2D(A, TexCoords).r; | ||
gl_FragColor = vec4(clamp(v, min, max)); | ||
} | ||
`; | ||
return { | ||
hasMain: true, | ||
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])], | ||
outputLayout: handler.createBasicTextureLayout(outputShape), | ||
shaderSource, | ||
}; | ||
} | ||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData { | ||
const inputTDs = [handler.getOrCreate(inputs[0], programInfo.inputLayouts[0])]; | ||
return { | ||
inputTextureDatas: inputTDs, | ||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].dataType), | ||
uniformData: {} | ||
}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT license. | ||
|
||
import {Attribute} from '../attribute'; | ||
import {InferenceHandler} from '../backend'; | ||
import {Operator} from '../operators'; | ||
import {Tensor} from '../tensor'; | ||
|
||
export abstract class Clip implements Operator { | ||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>; | ||
|
||
initialize(attributes: Attribute): void { | ||
this.min = attributes.getFloat('min', -3.4028234663852886e+38); | ||
this.max = attributes.getFloat('max', 3.4028234663852886e+38); | ||
} | ||
|
||
checkInputs(inputs: Tensor[]): boolean { | ||
if (!inputs || inputs.length !== 1) { | ||
return false; | ||
} | ||
|
||
return this.checkInputTypes(inputs); | ||
} | ||
|
||
protected checkInputTypes(inputs: Tensor[]): boolean { | ||
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') { | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
protected min: number; | ||
protected max: number; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT license. | ||
|
||
#include "clip.h" | ||
|
||
// Wasm interop methods | ||
void clip_f32(void *data) { | ||
uint32_t *dataIndex = static_cast<uint32_t *>(data); | ||
const float *input = PARAM_FLOAT_PTR(data, dataIndex[1]); | ||
float *output = PARAM_FLOAT_PTR(data, dataIndex[2]); | ||
const int32_t length = PARAM_INT32(data, dataIndex[3]); | ||
const float min = PARAM_FLOAT(data, dataIndex[4]); | ||
const float max = PARAM_FLOAT(data, dataIndex[5]); | ||
clip_imp<float>(input, output, length, min, max); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT license. | ||
|
||
#pragma once | ||
|
||
#include "common.h" | ||
|
||
extern "C" { | ||
void clip_f32(void *); | ||
// Expand for other supported data types for `clip` | ||
} | ||
|
||
// Core implementation of the op | ||
template <typename T> | ||
void clip_imp(const T *input, T *output, const int32_t length, const float min, | ||
const float max) { | ||
for (size_t i = 0; i < length; ++i) { | ||
const auto &val = input[i]; | ||
output[i] = (val < min) ? min : (val > max) ? max : val; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters