-
Notifications
You must be signed in to change notification settings - Fork 39
/
mvf_post_training.py
125 lines (98 loc) · 3.23 KB
/
mvf_post_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Created by albert aparicio on 28/10/16
# coding: utf-8
# This import makes Python use 'print' as in Python 3.x
from __future__ import print_function
import os
import h5py
import matplotlib
import numpy as np
from ahoproc_tools.error_metrics import RMSE
from keras.models import model_from_json
from keras.optimizers import RMSprop
from tfglib.utils import apply_context
matplotlib.use('TKagg')
from matplotlib import pyplot as plt
# ---------------------------------------------------------------------------
# Hyper-parameters and constants
# ---------------------------------------------------------------------------
# NOTE(review): batch_size and nb_epochs are not referenced anywhere else in
# this script; presumably they mirror the values of the training script.
batch_size = 300
nb_epochs = 700
# Step size handed to RMSprop when the restored model is re-compiled
learning_rate = 5.5e-7
# Context width passed to apply_context() on the source data
context_size = 1
##############
# Load model #
##############
# Rebuild the network from its JSON architecture, restore the trained weights
# and compile it so it can be used for prediction.
print('Loading model...', end='')
with open('models/mvf_model.json', 'r') as model_json:
    model = model_from_json(model_json.read())
model.load_weights('models/mvf_weights.h5')
# NOTE(review): `lr` is the argument name used by the Keras version this
# script targets; newer Keras releases expect `learning_rate` instead.
rmsprop = RMSprop(lr=learning_rate)
model.compile(loss='mae', optimizer=rmsprop)
# Terminate the progress message started above (printed with end=''), so the
# next progress message starts on its own line.
print('done')
#############
# Load data #
#############
# Load the normalization statistics that were saved for the training set.
# `dataset[()]` reads the full dataset contents. The `.value` attribute used
# previously was deprecated in h5py 2.x and removed in h5py 3.0.
# The `with` statement closes the file, so no explicit close() is needed.
with h5py.File('models/mvf_train_stats.h5', 'r') as train_stats:
    src_train_mean = train_stats['src_train_mean'][()]
    src_train_std = train_stats['src_train_std'][()]
    trg_train_mean = train_stats['trg_train_mean'][()]
    trg_train_std = train_stats['trg_train_std'][()]
# Load test data
print('Loading test data...', end='')
# The `with` statement closes the file; no explicit close() is needed.
with h5py.File('data/test_datatable.h5', 'r') as test_datatable:
    test_data = test_datatable['test_data'][:, :]

# Source data: columns 41-42 of the test table. Basic slicing returns a view,
# so take a copy to keep the in-place normalization below from silently
# rewriting `test_data` itself.
src_test_data = test_data[:, 41:43].copy()
# Normalize only the first source column with the training-set statistics
src_test_data[:, 0] = (src_test_data[:, 0] - src_train_mean) / src_train_std
# Target data: columns 84-85 of the test table (left un-normalized)
trg_test_data = test_data[:, 84:86]

# Apply context to the normalized first source column; the second source
# column is appended unchanged as the last column.
src_test_data_context = np.column_stack((
    apply_context(src_test_data[:, 0], context_size), src_test_data[:, 1]
))
print('done')
################
# Predict data #
################
print('Predicting')
prediction = model.predict(src_test_data_context)

# Map the first predicted column back to the original target scale using the
# training-set target statistics, so it is comparable to trg_test_data below.
denormalized = prediction[:, 0] * trg_train_std + trg_train_mean
prediction[:, 0] = denormalized
#################
# Error metrics #
#################
# RMSE between the first target column and the first (de-normalized)
# prediction column. The target's second column is passed as `mask`
# (presumably it selects which frames are scored; see ahoproc_tools.RMSE).
target_values = trg_test_data[:, 0]
predicted_values = prediction[:, 0]
target_mask = trg_test_data[:, 1]
rmse_test = RMSE(target_values, predicted_values, mask=target_mask)
print('Test RMSE: ', rmse_test)
# Load the training history and save the loss curves
results_dir = os.path.join('training_results', 'baseline')
# The `with` statement closes the file; no explicit close() is needed.
with h5py.File(os.path.join(results_dir, 'mvf_history.h5'), 'r') as hist_file:
    loss = hist_file['loss'][:]
    val_loss = hist_file['val_loss'][:]
    epoch = hist_file['epoch'][:]

print('Saving loss curves')
plt.plot(epoch, loss, epoch, val_loss)
plt.legend(['loss', 'val_loss'], loc='best')
# Pass the visibility flag positionally. The `b` keyword was deprecated in
# Matplotlib 3.5 (renamed to `visible`) and later removed, so `grid(b=True)`
# fails on current releases while `grid(True)` works on all of them.
plt.grid(True)
plt.suptitle('Baseline MVF Loss curves')
plt.savefig(os.path.join(results_dir, 'mvf_loss_curves.eps'),
            bbox_inches='tight')
# # Histogram of predicted training data and training data itself
# plt.hist(prediction[:, 0], bins=100)
# plt.title('Prediction frames')
# plt.savefig('prediction_hist.png', bbox_inches='tight')
# plt.show()
# # Histogram of training samples
# plt.figure()
# plt.hist(vf_gtruth, bins=100)
# plt.title('Training target frames')
# plt.savefig('gtruth_hist.png', bbox_inches='tight')
# plt.show()
import sys

print('========================' + '\n' +
      '======= FINISHED =======' + '\n' +
      '========================')
# `exit()` is injected by the `site` module for interactive sessions and is
# unavailable under `python -S`; `sys.exit()` is the supported way for a
# script to exit.
sys.exit()