Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
Add ELU activation function (#101)
Browse files Browse the repository at this point in the history
* Add ELU activation function

* Update ops-resolve.ts
  • Loading branch information
NTT123 authored and hariharans29 committed Mar 12, 2019
1 parent 7429f07 commit 3439e56
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/backends/cpu/ops-resolve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ function createOperator(node: Graph.Node, domain: string, version: number): Oper
// opLambda will be resolved when the op is initialized at which time it will have context of the attribute
// 'alpha'
return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.leakyRelu);
case 'Elu':
return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.elu);
case 'MatMul':
return new CpuMatMul();
case 'AveragePool':
Expand Down
8 changes: 8 additions & 0 deletions lib/backends/cpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ export function leakyRelu(input: Tensor.NumberType, output: Tensor.NumberType, a
}
}

/**
 * ELU (Exponential Linear Unit) applied element-wise:
 * f(x) = x for x >= 0, alpha * (e^x - 1) for x < 0.
 * Reads the optional 'alpha' attribute (ONNX default: 1.0).
 */
export function elu(input: Tensor.NumberType, output: Tensor.NumberType, attributes: Attribute) {
  // Resolve alpha once; it is constant across all elements.
  const alpha = attributes.getFloat('alpha', 1.0);
  for (let i = 0; i < input.length; ++i) {
    const x = input[i];
    if (x >= 0) {
      output[i] = x;
    } else {
      output[i] = alpha * (Math.exp(x) - 1.0);
    }
  }
}

export function log(input: Tensor.NumberType, output: Tensor.NumberType, attributes: Attribute) {
for (let i = 0; i < input.length; i++) {
output[i] = Math.log(input[i]);
Expand Down
40 changes: 40 additions & 0 deletions lib/backends/webgl/ops/elu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Elu} from '../../../ops/elu';
import {Tensor} from '../../../tensor';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo} from '../program-info';
import {RunData} from '../program-manager';
import {WebGLOperator} from '../webgl-operator';
import {WebGLOperatorHelper} from '../webgl-operator-utils';

/** WebGL implementation of the ELU operator (element-wise fragment shader). */
export class WebGLElu extends Elu implements WebGLOperator {
  // Delegate to the shared WebGL pipeline: program lookup/compile, then draw.
  run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
    return WebGLOperatorHelper.run(this, inferenceHandler, inputs);
  }
  // The output has the same shape as the (single) input. Alpha is baked into
  // the shader source as a float literal, so each alpha gets its own program.
  createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
    const shaderSource = `
      uniform sampler2D A;
      void main() {
        float v = texture2D(A, TexCoords).r;
        gl_FragColor = vec4(v >= 0.0 ? v: (exp(v) - 1.0) * ${this.alpha.toExponential()}); /* float number format */
      }
      `;
    const outputDims = inputs[0].dims.slice();
    return {
      hasMain: true,
      inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
      outputLayout: handler.createBasicTextureLayout(outputDims),
      shaderSource,
    };
  }
  // One input texture, one output texture, no uniforms (alpha is compiled in).
  createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
    const textureDatas = [handler.getOrCreate(inputs[0], programInfo.inputLayouts[0])];
    const outputTextureData = handler.createTextureDataFromLayout(programInfo.outputLayout, textureDatas[0].dataType);
    return {inputTextureDatas: textureDatas, outputTextureData, uniformData: {}};
  }
}
3 changes: 3 additions & 0 deletions lib/backends/webgl/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import * as binaryOps from './ops/binary-op';
import {WebGLConcat} from './ops/concat';
import {WebGLConv} from './ops/conv';
import {WebGLDropout} from './ops/dropout';
import {WebGLElu} from './ops/elu';
import {WebGLFlatten} from './ops/flatten';
import {WebGLGather} from './ops/gather';
import {WebGLGemm} from './ops/gemm';
Expand Down Expand Up @@ -144,6 +145,8 @@ export class WebGLSessionHandler implements SessionHandler {
return new WebGLImageScaler();
case 'LeakyRelu':
return new WebGLLeakyRelu();
case 'Elu':
return new WebGLElu();
case 'Less':
return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslLess(), undefined, 'bool');
case 'Log':
Expand Down
33 changes: 33 additions & 0 deletions lib/ops/elu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Attribute} from '../attribute';
import {InferenceHandler} from '../backend';
import {Operator} from '../operators';
import {Tensor} from '../tensor';

/**
 * Backend-agnostic base class for the ONNX Elu operator.
 * Holds attribute parsing and input validation; each backend supplies `run`.
 */
export abstract class Elu implements Operator {
  abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

  // Reads the optional 'alpha' attribute; ONNX specifies a default of 1.0.
  initialize(attributes: Attribute): void {
    this.alpha = attributes.getFloat('alpha', 1.0);
  }

  // Elu is unary: exactly one input tensor of a supported type.
  checkInputs(inputs: Tensor[]): boolean {
    return !!inputs && inputs.length === 1 && this.checkInputTypes(inputs);
  }

  // Only floating-point tensors are accepted.
  protected checkInputTypes(inputs: Tensor[]): boolean {
    const type = inputs[0].type;
    return type === 'float32' || type === 'float64';
  }

  protected alpha: number;
}
6 changes: 6 additions & 0 deletions test/unittest-whitelist.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
"test_div",
"test_dropout_default",
"test_dropout_random",
"test_elu_example",
"test_elu",
"test_elu_default",
"test_flatten_axis0",
"test_flatten_axis1",
"test_flatten_axis2",
Expand Down Expand Up @@ -291,6 +294,9 @@
"test_div",
"test_dropout_default",
"test_dropout_random",
"test_elu_example",
"test_elu",
"test_elu_default",
"test_flatten_axis0",
"test_flatten_axis1",
"test_flatten_axis2",
Expand Down

0 comments on commit 3439e56

Please sign in to comment.