forked from microsoft/onnxscript
-
Notifications
You must be signed in to change notification settings - Fork 0
/
01_plot_selu.py
38 lines (28 loc) · 916 Bytes
/
01_plot_selu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Generating a FunctionProto
==========================
The example below shows how we can define Selu as a function in onnxscript.
"""
# %%
# First, import the ONNX opset used to define the function.
from onnxscript import opset15 as op
from onnxscript import script
# %%
# Next, define Selu as an ONNXScript function.
@script()
def Selu(X, alpha: float, gamma: float):
    # Selu, computed piecewise and selected element-wise with Where:
    #   X <= 0: gamma * (alpha * exp(X) - alpha)
    #   X >  0: gamma * X
    # alpha and gamma are annotated as Python floats; presumably onnxscript
    # turns them into attributes of the generated FunctionProto (confirm
    # against the onnxscript docs).
    # Cast the float parameters to X's element type so the arithmetic below
    # operates on a single element type.
    alphaX = op.CastLike(alpha, X)
    gammaX = op.CastLike(gamma, X)
    # Branch used for non-positive inputs. Both branches are computed over
    # the whole of X; Where picks the relevant value per element afterwards.
    neg = gammaX * (alphaX * op.Exp(X) - alphaX)
    # Branch used for positive inputs.
    pos = gammaX * X
    zero = op.CastLike(0, X)
    # `zero >= X` is the same test as `X <= 0`, so non-positive elements take
    # `neg` and strictly positive elements take `pos`.
    return op.Where(zero >= X, neg, pos)
# %%
# Calling ``to_function_proto()`` on the scripted function translates it
# into an ONNX function (a ``FunctionProto``):
onnx_fun = Selu.to_function_proto()

# %%
# Finally, render the generated FunctionProto in ONNX's textual syntax to
# inspect what the translation produced:
import onnx  # noqa: E402

function_text = onnx.printer.to_text(onnx_fun)
print(function_text)