-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
106 lines (77 loc) · 2.98 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import pylab as plt
import json
import numpy as np
from torch import nn
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy and PyTorch values.

    Conversions performed in ``default``:

    - ``np.ndarray`` -> nested Python list (via ``tolist()``)
    - NumPy scalars (``np.float32``, ``np.int64``, ``np.bool_``, ...) -> the
      equivalent Python scalar (via ``item()``)
    - ``torch.Tensor`` -> nested Python list, or a Python scalar for a 0-d
      tensor (via ``tolist()``)

    Any other object falls back to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj* (see class docstring)."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # e.g. np.float32 / np.int64 values, which the stdlib encoder rejects
            return obj.item()
        if isinstance(obj, torch.Tensor):
            # detach/cpu so tensors that require grad or live on a GPU work too
            return obj.detach().cpu().tolist()
        return super().default(obj)
class CustomLayers(nn.Module):
    """A single linear layer followed by a sigmoid.

    With the default sizes this maps vectors whose last dimension is 4096 to a
    single value in (0, 1) per sample.

    Args:
        in_features (int): size of the last dimension of the input. Defaults to 4096.
        out_features (int): number of outputs per sample. Defaults to 1.
    """

    def __init__(self, in_features=4096, out_features=1):
        super().__init__()
        # Attribute names ``fc`` / ``fn`` are kept so state_dict keys stay the same.
        self.fc = nn.Linear(in_features, out_features)
        self.fn = nn.Sigmoid()

    def forward(self, x):
        """Apply the linear layer, then the sigmoid.

        Args:
            x (torch.Tensor): input whose last dimension equals ``in_features``.

        Returns:
            torch.Tensor: same leading dimensions as ``x``, last dimension
            ``out_features``, values in (0, 1).
        """
        x = self.fc(x)
        x = self.fn(x)
        return x
def save_model(model, path):
    """Write a PyTorch model to disk with ``torch.save``.

    The module object itself is passed to ``torch.save`` (not just its
    ``state_dict``).

    Args:
        model (torch.nn.Module): the trained model to store.
        path (str): destination file for the saved model.
    """
    torch.save(obj=model, f=path)
def save_results_ae(saved_values, json_path, train_losses, test_losses):
    """Save the autoencoder's training history to a JSON file.

    ``saved_values`` is updated in place with the keys ``"Train_Errors"`` and
    ``"Validation_Errors"``, and the whole dict (including any keys it already
    held) is then written to ``json_path``. NumPy arrays are serialized via
    ``NumpyEncoder``.

    Args:
        saved_values (dict): dict (typically empty) to fill with the losses and
            dump to JSON; modified in place.
        json_path (str): path of the JSON file to write; overwritten if it exists.
        train_losses (list): train losses recorded during training.
        test_losses (list): validation losses recorded during training.
    """
    saved_values["Train_Errors"] = train_losses
    saved_values["Validation_Errors"] = test_losses
    # The ``with`` statement closes the file; no explicit close() is needed.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(saved_values, f, indent=4, cls=NumpyEncoder)
def save_results(saved_values, json_path, train_losses, test_losses,
                 train_accuracy, test_accuracy, test_acc):
    """Save the losses and accuracies recorded during training to a JSON file.

    ``saved_values`` is updated in place with the keys ``"Train_Errors"``,
    ``"Validation_Errors"``, ``"Train_Accuracy"``, ``"Validation_Accuracy"`` and
    ``"Test_Accuracy"``, and the whole dict (including any keys it already held)
    is then written to ``json_path``. NumPy arrays are serialized via
    ``NumpyEncoder``.

    Args:
        saved_values (dict): dict (typically empty) to fill with the results and
            dump to JSON; modified in place.
        json_path (str): path of the JSON file to write; overwritten if it exists.
        train_losses (list): train losses recorded during training.
        test_losses (list): validation losses recorded during training.
        train_accuracy (list): train accuracies recorded during training.
        test_accuracy (list): validation accuracies recorded during training.
        test_acc: accuracy on the test set, stored under ``"Test_Accuracy"``
            (presumably a single number -- verify against callers).
    """
    saved_values["Train_Errors"] = train_losses
    saved_values["Validation_Errors"] = test_losses
    saved_values["Train_Accuracy"] = train_accuracy
    saved_values["Validation_Accuracy"] = test_accuracy
    saved_values["Test_Accuracy"] = test_acc
    # The ``with`` statement closes the file; no explicit close() is needed.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(saved_values, f, indent=4, cls=NumpyEncoder)
def plot_curve(train_losses, test_losses, path, label):
    """Plot train and test curves on one figure and save it as an image.

    Each sequence is plotted against its index; the x-axis is labelled
    "Epoch", so presumably one value per epoch.

    Args:
        train_losses (list): values recorded on the training set.
        test_losses (list): values recorded on the test set.
        path (str): path for saving the learning-curve image.
        label (str): what is plotted (e.g. loss or accuracy); used in the
            legend entries and as the title.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(train_losses, label='Train ' + label)
        ax.plot(test_losses, label='Test ' + label)
        ax.set_xlabel("Epoch")
        ax.set_title(label)
        ax.legend()
        fig.savefig(path)
    finally:
        # Release the figure; otherwise every call leaves an open figure
        # behind and memory grows when curves are plotted repeatedly.
        plt.close(fig)