-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path01_train_torch_chempy.py
179 lines (133 loc) · 5.98 KB
/
01_train_torch_chempy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from Chempy.parameter import ModelParameters
import torch
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
import time as t
import os
# ----- Load the data ---------------------------------------------------------------------------------------------------------------------------------------------
# The data files are resolved relative to this script. os.path.join is used instead of
# string concatenation: when dirname(__file__) is '' (Python < 3.9 run as
# `python 01_train_torch_chempy.py`), concatenation would produce the absolute path
# '/data/...', while join correctly yields the relative path 'data/...'.
data_dir = os.path.join(os.path.dirname(__file__), 'data', 'chempy_data')

# --- Load in training data ---
# np.load on an .npz returns an NpzFile that keeps the archive open; the `with` block
# closes it once the arrays have been read into memory. (mmap_mode has no effect on
# .npz archives, so it is not passed.)
path_training = os.path.join(data_dir, 'chempy_TNG_train_data.npz')
with np.load(path_training) as training_data:
    elements = training_data['elements']
    train_x = training_data['params']
    train_y = training_data['abundances']

# --- Load in the validation data ---
path_test = os.path.join(data_dir, 'chempy_TNG_val_data.npz')
with np.load(path_test) as val_data:
    val_x = val_data['params']
    val_y = val_data['abundances']
# --- Clean the data ---
# Chempy sometimes returns zeros or infinite values, which need to be removed
def clean_data(x, y):
    """Drop samples whose target row is unusable.

    A row is discarded when every entry of ``y`` in that row equals zero, or when
    any entry of that row is non-finite (inf or NaN). The same rows are dropped
    from ``x`` so inputs and targets stay aligned; row order is preserved.

    Parameters
    ----------
    x : array-like
        Inputs, one sample per row, row-aligned with ``y``.
    y : ndarray, 2-D
        Targets; both filters reduce over ``axis=1``.

    Returns
    -------
    tuple of ndarray
        The filtered ``(x, y)``.
    """
    inputs = np.asarray(x)
    all_zero_rows = (y == 0).all(axis=1)
    finite_rows = np.isfinite(y).all(axis=1)
    keep = finite_rows & ~all_zero_rows
    return inputs[keep], y[keep]
(train_x, train_y), (val_x, val_y) = clean_data(train_x, train_y), clean_data(val_x, val_y)
# The network is trained in single precision: every split becomes a float32 torch
# tensor (torch.tensor copies the underlying NumPy data).
train_x, train_y, val_x, val_y = (
    torch.tensor(split, dtype=torch.float32)
    for split in (train_x, train_y, val_x, val_y)
)
# ----- Define the prior ------------------------------------------------------------------------------------------------------------------------------------------
# Chempy's ModelParameters provides the names of the optimized parameters
# (`to_optimize`) and a per-name prior entry (`priors[name]`).
a = ModelParameters()
# One label per optimized parameter, followed by 'time'.
labels = list(a.to_optimize) + ['time']
# Tensor of shape (len(to_optimize), 2): the first two entries of each prior
# (their meaning is defined by Chempy -- NOTE(review): confirm, e.g. location/scale).
priors = torch.tensor([[a.priors[opt][0], a.priors[opt][1]] for opt in a.to_optimize])
# ----- Define the model ------------------------------------------------------------------------------------------------------------------------------------------
class Model_Torch(torch.nn.Module):
    """Fully connected network mapping input parameters to predicted abundances.

    Architecture: Linear(n_in, h1) -> tanh -> Linear(h1, h2) -> tanh -> Linear(h2, n_out).
    The layer attribute names ``l1``, ``l2``, ``l3`` are kept stable so previously
    saved state dicts remain loadable.

    Parameters
    ----------
    n_in : int, optional
        Number of input features. Defaults to ``train_x.shape[1]`` from module scope.
    n_out : int, optional
        Number of outputs. Defaults to ``train_y.shape[1]`` from module scope.
    hidden : tuple of two ints, optional
        Widths of the two hidden layers; defaults to ``(100, 40)``.
    """

    def __init__(self, n_in=None, n_out=None, hidden=(100, 40)):
        super().__init__()
        # Fall back to the training-data shapes so `Model_Torch()` builds exactly the
        # original network; the module-level tensors are only read when needed.
        if n_in is None:
            n_in = train_x.shape[1]
        if n_out is None:
            n_out = train_y.shape[1]
        h1, h2 = hidden
        self.l1 = torch.nn.Linear(n_in, h1)
        self.l2 = torch.nn.Linear(h1, h2)
        self.l3 = torch.nn.Linear(h2, n_out)

    def forward(self, x):
        """Map a batch ``x`` (last dimension ``n_in``) to outputs (last dimension ``n_out``)."""
        x = torch.tanh(self.l1(x))
        x = torch.tanh(self.l2(x))
        return self.l3(x)
model = Model_Torch()
# ----- Train the model -------------------------------------------------------------------------------------------------------------------------------------------
# Mean-squared-error objective optimized with Adam at a learning rate of 1e-3.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# shuffle the data: apply one random permutation to inputs and targets together
index = np.random.permutation(train_x.shape[0])
train_x, train_y = train_x[index], train_y[index]
# --- Train the neural network ---
epochs = 20
batch_size = 64
ep_loss = []  # one [mean training loss, validation loss] pair per epoch
n_train = train_x.shape[0]
start = t.time()
for epoch in range(epochs):
    start_epoch = t.time()
    model.train()
    # Draw a fresh permutation every epoch so the mini-batches differ between epochs;
    # a single up-front shuffle would replay identical batches each epoch.
    perm = torch.randperm(n_train)
    train_loss = []
    for i in range(0, n_train, batch_size):
        optimizer.zero_grad()
        # Get the batch. Inputs and targets are plain data and need no gradients.
        batch_idx = perm[i:i + batch_size]
        x_batch = train_x[batch_idx]
        y_batch = train_y[batch_idx]
        # Forward pass and loss
        loss = loss_fn(model(x_batch), y_batch)
        train_loss.append(loss.item())
        # Backward pass: the graph is rebuilt every step, so it does not need to be retained.
        loss.backward()
        optimizer.step()
    # Validation loss: evaluation only, so no autograd graph is built over the full set.
    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(val_x), val_y).item()
    train_loss = np.array(train_loss).mean()
    ep_loss.append([train_loss, val_loss])
    end_epoch = t.time()
    epoch_time = end_epoch - start_epoch
    print(f'Epoch {epoch+1}/{epochs} in {round(epoch_time,1)}s, Loss: {round(train_loss,6)} | Val Loss: {round(val_loss,6)}')
# Measured here rather than from `end_epoch`, which is undefined when epochs == 0.
print(f'Training finished | Total time: {round(t.time() - start, 1)}s')
# ----- Plot the loss -----
# One row per epoch: column 0 = mean training loss, column 1 = validation loss.
# reshape(-1, 2) keeps the column indexing valid even when no epochs were run.
ep_loss = np.array(ep_loss).reshape(-1, 2)
epoch_axis = np.arange(1, len(ep_loss) + 1)
fig, ax = plt.subplots()
ax.plot(epoch_axis, ep_loss[:, 0], label='Training Loss')
ax.plot(epoch_axis, ep_loss[:, 1], label='Validation Loss')
ax.set_xlabel('Epoch', fontsize=15)
ax.set_ylabel('MSE Loss', fontsize=15)
ax.set_title('Training and Validation Loss', fontsize=20)
ax.legend()
fig.tight_layout()
# savefig does not create missing directories.
# NOTE(review): this path is relative to the working directory, while the training data is
# read relative to this script -- confirm which is intended.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/loss_NN_simulator.png")
plt.close(fig)  # release the figure instead of only clearing it
# ----- Calculate the Absolute Percentage Error -----
# APE = 100 * |true - predicted| / |true| for every validation abundance.
# Prediction only, so no autograd graph is built.
with torch.no_grad():
    val_pred = model(val_x)
ape = (100 * torch.abs((val_y - val_pred) / val_y)).numpy().ravel()
# A zero true value yields inf/NaN. A single NaN would make every percentile below NaN,
# so only the finite errors are summarized.
ape = ape[np.isfinite(ape)]
fig, (ax_box, ax_hist) = plt.subplots(2, sharex=True, gridspec_kw={"height_ratios": (.20, .80)})
ax_hist.hist(ape, bins=100, density=True, range=(0, 30), color='tomato')
ax_hist.set_xlabel('Error (%)', fontsize=15)
ax_hist.set_ylabel('Density', fontsize=15)
ax_hist.spines['top'].set_visible(False)
ax_hist.spines['right'].set_visible(False)
# percentiles: median (dashed) and the 25th / 75th percentiles (dotted)
p1, p2, p3 = np.percentile(ape, [25, 50, 75])
ax_hist.axvline(p2, color='black', linestyle='--')
ax_hist.axvline(p1, color='black', linestyle='dotted')
ax_hist.axvline(p3, color='black', linestyle='dotted')
# Median annotated with its distance to each quartile; y=0.2 is in data (density) units.
ax_hist.text(p2, 0.2, fr'${p2:.1f}^{{+{p3-p2:.1f}}}_{{-{p2-p1:.1f}}}\%$', fontsize=12, verticalalignment='top')
ax_box.boxplot(ape, vert=False, autorange=False, widths=0.5, patch_artist=True, showfliers=False, boxprops=dict(facecolor='tomato'), medianprops=dict(color='black'))
ax_box.set(yticks=[])
for side in ('left', 'right', 'top'):
    ax_box.spines[side].set_visible(False)
fig.suptitle('APE of the Neural Network', fontsize=20)
ax_hist.set_xlim(0, 30)  # x-axis is shared, so this also limits the box plot
fig.tight_layout()
# savefig does not create missing directories.
# NOTE(review): path is relative to the working directory -- confirm this is intended.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/ape_NN.png")
plt.close(fig)
# ----- Save the model --------------------------------------------------------------------------------------------------------------------------------------------
# Only the parameters (state_dict) are written; reloading requires constructing a
# Model_Torch with matching layer sizes first.
# torch.save does not create missing directories. Without this, a run started from another
# working directory would crash here, after all training work is done.
# NOTE(review): the path is relative to the working directory, whereas the training data is
# read relative to this script -- confirm which is intended.
os.makedirs('data', exist_ok=True)
torch.save(model.state_dict(), 'data/pytorch_state_dict.pt')
print("Model trained and saved")