diff --git a/lib/backends/cpu/op-resolve-rules.ts b/lib/backends/cpu/op-resolve-rules.ts
index 972ee3b6..9aa182b9 100644
--- a/lib/backends/cpu/op-resolve-rules.ts
+++ b/lib/backends/cpu/op-resolve-rules.ts
@@ -58,6 +58,7 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
   ['Div', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 / e2))],
   ['Dropout', '', '7+', () => new CpuDropout()],
   ['Elu', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.elu, unaryOps.eluInitializer)],
+  ['Selu', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.selu, unaryOps.seluInitializer)],
   ['Exp', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.exp)],
   ['Expand', '', '8+', () => new CpuExpand()],
   ['Flatten', '', '1+', () => new CpuFlatten()],
diff --git a/lib/backends/cpu/ops/unary-op.ts b/lib/backends/cpu/ops/unary-op.ts
index 056f0308..1c0c97a5 100644
--- a/lib/backends/cpu/ops/unary-op.ts
+++ b/lib/backends/cpu/ops/unary-op.ts
@@ -134,6 +134,27 @@ export function elu(input: Tensor.NumberType, output: Tensor.NumberType, attribu
   }
 }
 
+/** Reads the Selu 'alpha'/'gamma' attributes, falling back to the ONNX spec defaults. */
+export function seluInitializer(attributes: Attribute) {
+  return {
+    alpha: attributes.getFloat('alpha', 1.67326319217681884765625),
+    gamma: attributes.getFloat('gamma', 1.05070102214813232421875)
+  };
+}
+
+/**
+ * SELU activation, applied element-wise:
+ *   y = gamma * x                        for x >= 0
+ *   y = gamma * (alpha * exp(x) - alpha) for x < 0
+ */
+export function selu(input: Tensor.NumberType, output: Tensor.NumberType, attributes: {alpha: number, gamma: number}) {
+  const alpha = attributes.alpha;
+  const gamma = attributes.gamma;
+  for (let i = 0; i < input.length; i++) {
+    const value = input[i];
+    output[i] = value >= 0 ? gamma * value : gamma * (alpha * Math.exp(value) - alpha);
+  }
+}
+
 export function exp(input: Tensor.NumberType, output: Tensor.NumberType) {
   for (let i = 0; i < input.length; i++) {
     output[i] = Math.exp(input[i]);
diff --git a/lib/ops/selu.ts b/lib/ops/selu.ts
new file mode 100644
index 00000000..26b7fc5a
--- /dev/null
+++ b/lib/ops/selu.ts
@@ -0,0 +1,33 @@
+import {Attribute} from '../attribute';
+import {InferenceHandler} from '../backend';
+import {Operator} from '../operators';
+import {Tensor} from '../tensor';
+
+/** Backend-agnostic base for the Selu operator: attribute handling and input validation. */
+export abstract class Selu implements Operator {
+  abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
+
+  initialize(attributes: Attribute): void {
+    this.alpha = attributes.getFloat('alpha', 1.67326319217681884765625);
+    this.gamma = attributes.getFloat('gamma', 1.05070102214813232421875);
+  }
+
+  checkInputs(inputs: Tensor[]): boolean {
+    if (!inputs || inputs.length !== 1) {
+      return false;
+    }
+
+    return this.checkInputTypes(inputs);
+  }
+
+  protected checkInputTypes(inputs: Tensor[]): boolean {
+    if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
+      return false;
+    }
+
+    return true;
+  }
+
+  protected alpha: number;
+  protected gamma: number;
+}