-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
74 lines (64 loc) · 2.72 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from keras.callbacks import Callback
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import pickle
import sys
import json
import operator
def read_params(path_to_config):
    """Load a JSON configuration file and return its parsed contents.

    If the process was started with exactly one command-line argument
    (``len(sys.argv) == 2``), that argument is used as the config path and
    *path_to_config* is ignored.

    :param path_to_config: Default path to the JSON config file.
    :return: The deserialized JSON value (a dict when the file holds a JSON object).
    """
    config_path = path_to_config
    if len(sys.argv) == 2:
        # A single CLI argument overrides the default config location.
        config_path = sys.argv[1]
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale.
    with open(config_path, encoding='utf-8') as config_file:
        return json.load(config_file)
class AucMetricHistory(Callback):
    """Keras callback that scores the model by ROC AUC on the validation set.

    After every epoch the validation AUC is written into ``logs['val_auc']``,
    the best score seen so far is tracked in ``best_auc`` / ``best_epoch``,
    and, when ``save_best_by_auc`` is enabled, the model is saved to
    ``path_to_save`` every time that best score improves.
    """

    def __init__(self, save_best_by_auc=False, path_to_save=None):
        """
        :param save_best_by_auc: Save the model whenever validation AUC improves.
        :param path_to_save: Path handed to ``self.model.save``; required when
            ``save_best_by_auc`` is true.
        :raises ValueError: If ``save_best_by_auc`` is true but no path is given.
        """
        super(AucMetricHistory, self).__init__()
        self.save_best_by_auc = save_best_by_auc
        self.path_to_save = path_to_save
        # Best validation AUC so far; best_epoch is only a placeholder until the
        # first improvement overwrites it with the ``epoch`` argument.
        self.best_auc = 0
        self.best_epoch = 1
        if save_best_by_auc and (path_to_save is None):
            raise ValueError('Specify path to save the model')

    def on_epoch_end(self, epoch, logs=None):
        """Compute validation ROC AUC, record it, and track/save the best model.

        :param epoch: Epoch index supplied by Keras.
        :param logs: Per-epoch metrics dict supplied by Keras; ``'val_auc'``
            is added to it in place.
        """
        # Avoid a shared mutable default: a fresh dict when Keras passes none.
        if logs is None:
            logs = {}
        # NOTE(review): relies on Keras setting ``self.validation_data`` during
        # fit(); newer Keras / tf.keras versions no longer do — confirm version.
        x_val, y_val = self.validation_data[0], self.validation_data[1]
        # Predict the whole validation set as a single batch.
        y_pred = self.model.predict(x_val, batch_size=len(y_val), verbose=0)
        if isinstance(y_pred, list):
            # Multi-output model: score only the first output
            # (presumably the classifier head — confirm).
            y_pred = y_pred[0]
        current_auc = roc_auc_score(y_val, y_pred)
        logs['val_auc'] = current_auc
        if current_auc > self.best_auc:
            self.best_auc = current_auc
            self.best_epoch = epoch
            # Previously save_best_by_auc/path_to_save were validated but never
            # used, so the "save best model" option silently did nothing.
            if self.save_best_by_auc:
                self.model.save(self.path_to_save)
def single_auc_loging(history, title, path_to_save):
    """
    Plot neural-network classifier performance and persist it.

    Makes a figure with two subplots:
      * left: train and validation classifier loss, titled with the minimum
        validation loss and the (0-based) index of the epoch where it occurred;
      * right: validation AUC, titled with its maximum and its epoch index.

    The figure is saved as ``<path_to_save>/<title>.png`` and the history
    dict is pickled to ``<path_to_save>/<title>.pkl``.

    :param history: The ``history`` attribute of the History object returned
        by ``model.fit()`` (a dict of per-epoch metric lists). Must contain
        ``'val_auc'`` and either ``'loss'``/``'val_loss'`` or
        ``'class_out_loss'``/``'val_class_out_loss'``.
    :param title: Title for the figure; also used as the base file name.
    :param path_to_save: Directory in which the .png and .pkl files are written.
    :raises ValueError: If no known loss key is present in ``history``.
    :return: None
    """
    import os

    # Validate before creating the figure so an error cannot leak it.
    if 'loss' in history:
        loss_key = 'loss'  # for simple NN
    elif 'class_out_loss' in history:
        loss_key = 'class_out_loss'  # for DAL NN
    else:
        raise ValueError('Not found correct key for loss information in history')
    val_loss_key = 'val_%s' % loss_key

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 12))
    try:
        ax1.plot(history[loss_key], label='cl train loss')
        ax1.plot(history[val_loss_key], label='cl val loss')
        ax1.legend()
        # Summarize the same validation-loss series that is plotted above; the
        # old code always read 'val_loss', which need not match (or exist)
        # when loss_key is 'class_out_loss'.
        min_loss_index, min_loss_value = min(
            enumerate(history[val_loss_key]), key=operator.itemgetter(1))
        ax1.set_title('min_loss_%.3f_epoch%d' % (min_loss_value, min_loss_index))

        ax2.plot(history['val_auc'])
        max_auc_index, max_auc_value = max(
            enumerate(history['val_auc']), key=operator.itemgetter(1))
        ax2.set_title('max_auc_%.3f_epoch%d' % (max_auc_value, max_auc_index))

        fig.suptitle(str(title))
        # Save this specific figure; ``figure=`` is not a savefig keyword and
        # recent matplotlib versions reject unknown savefig kwargs.
        fig.savefig(os.path.join(path_to_save, '%s.png' % title))
    finally:
        plt.close(fig)

    with open(os.path.join(path_to_save, '%s.pkl' % title), 'wb') as output:
        pickle.dump(history, output, pickle.HIGHEST_PROTOCOL)