-
Notifications
You must be signed in to change notification settings - Fork 96
/
hello_world.py
77 lines (61 loc) · 2.41 KB
/
hello_world.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import functools
import numpy as np
import matplotlib.pyplot as plt
import evidential_deep_learning as edl
import tensorflow as tf
def main():
    """Fit an evidential regression model to a noisy cubic and plot the result.

    Training data is drawn on [-4, 4]; the model is then evaluated on the
    wider range [-7, 7] so the plot also shows its behaviour outside the
    region it was trained on.
    """
    # Noisy samples for fitting, noise-free samples on a wider range for
    # evaluation.
    x_train, y_train = my_data(-4, 4, 1000)
    x_test, y_test = my_data(-7, 7, 1000, train=False)

    # Two hidden ReLU layers followed by the evidential output layer.
    layers = [
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(64, activation="relu"),
        edl.layers.DenseNormalGamma(1),
    ]
    model = tf.keras.Sequential(layers)

    # Keras calls the loss with (true, pred) only, so the regularizer
    # coefficient is pinned through this small wrapper.
    def EvidentialRegressionLoss(true, pred):
        return edl.losses.EvidentialRegression(true, pred, coeff=1e-2)

    adam = tf.keras.optimizers.Adam(5e-4)
    model.compile(optimizer=adam, loss=EvidentialRegressionLoss)
    model.fit(x_train, y_train, batch_size=100, epochs=500)

    # Run the trained model on the test grid and visualise its output.
    y_pred = model(x_test)
    plot_predictions(x_train, y_train, x_test, y_test, y_pred)
#### Helper functions ####
def my_data(x_min, x_max, n, train=True):
    """Sample the cubic ``y = x**3`` on an evenly spaced grid.

    Args:
        x_min: First grid value.
        x_max: Last grid value.
        n: Number of grid points.
        train: If true, Gaussian noise with standard deviation 3 is added
            to ``y``; otherwise the noise scale is zero.

    Returns:
        A tuple ``(x, y)`` of float32 arrays, each with a trailing axis of
        length 1 appended to the grid.
    """
    grid = np.linspace(x_min, x_max, n)
    x = grid[..., np.newaxis].astype(np.float32)

    if train:
        noise_scale = 3 * np.ones_like(x)
    else:
        noise_scale = np.zeros_like(x)

    # The sampler is called in both modes, so the global NumPy random state
    # is used the same way whether or not noise is requested.
    noise = np.random.normal(0, noise_scale).astype(np.float32)
    y = x**3 + noise
    return x, y
def plot_predictions(x_train, y_train, x_test, y_test, y_pred, n_stds=4, kk=0):
    """Plot training data, ground truth, predicted mean and uncertainty bands.

    Args:
        x_train: Training inputs, drawn as a scatter; their min/max also
            place the dashed vertical markers of the training range.
        y_train: Training targets, drawn as a scatter.
        x_test: Test inputs with a trailing feature axis; column 0 is the
            x coordinate and its min/max set the x-axis limits.
        y_test: Ground-truth targets for ``x_test``.
        y_pred: Model output whose last axis splits into four equal parts
            ``(mu, v, alpha, beta)``. Anything ``np.asarray`` can convert
            is accepted.
        n_stds: Width of the widest band, in multiples of the uncertainty
            estimate.
        kk: Unused; kept so existing callers keep working.
    """
    y_lim = 150  # fixed vertical extent of the figure

    x_train = np.asarray(x_train)
    x_test = np.asarray(x_test)[:, 0]

    # Convert once so everything below is plain ndarray arithmetic instead
    # of a mix of framework tensors and NumPy arrays.
    y_pred = np.asarray(y_pred)
    mu, v, alpha, beta = np.split(y_pred, 4, axis=-1)
    mu = mu[:, 0]

    # Square root of beta / (v * (alpha - 1)); presumably the epistemic
    # standard deviation of the evidential output -- confirm against the
    # library. Clipped so extreme values do not dominate the figure.
    std = np.sqrt(beta / (v * (alpha - 1)))
    std = np.minimum(std, 1e3)[:, 0]

    plt.figure(figsize=(5, 3), dpi=200)
    plt.scatter(x_train, y_train, s=1., c='#463c3c', zorder=0, label="Train")
    plt.plot(x_test, y_test, 'r--', zorder=2, label="True")
    plt.plot(x_test, mu, color='#007cab', zorder=3, label="Pred")

    # Mark the edges of the training range, taken from the data itself so
    # the markers stay correct if the caller changes the range.
    if x_train.size:
        for bound in (x_train.min(), x_train.max()):
            plt.plot([bound, bound], [-y_lim, y_lim], 'k--', alpha=0.4,
                     zorder=0)

    # Overlapping translucent bands: wider multiples are covered by fewer
    # bands and therefore appear lighter.
    for k in np.linspace(0, n_stds, 4):
        plt.fill_between(
            x_test, (mu - k * std), (mu + k * std),
            alpha=0.3,
            edgecolor=None,
            facecolor='#00aeef',
            linewidth=0,
            zorder=1,
            # Only the first band carries a label, so the legend gets a
            # single "Unc." entry.
            label="Unc." if k == 0 else None)

    axes = plt.gca()
    axes.set_ylim(-y_lim, y_lim)
    if x_test.size:
        axes.set_xlim(x_test.min(), x_test.max())
    plt.legend(loc="upper left")
    plt.show()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()