-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
88 lines (78 loc) · 3.13 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from __future__ import absolute_import, division, print_function
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
import argparse
import os
import DSNNr as DSNNr
import math
import time
import matplotlib.pyplot as plt
import matplotlib as mpl
from cycler import cycler
def _load_npz_sample(path, prefix):
    """Read one sample's arrays from a cached ``.npz`` archive.

    Looks up the members ``prefix``, ``prefix + "_weights"`` and
    ``prefix + "_spec"`` (e.g. ``"MCa"``, ``"MCa_weights"``, ``"MCa_spec"``).

    Args:
        path: path of the ``.npz`` archive (relative to the current directory).
        prefix: member-name prefix identifying the sample.

    Returns:
        ``(data, weights, spec)`` as in-memory NumPy arrays.

    Raises:
        KeyError: if one of the expected members is missing from the archive.
    """
    # np.load on an .npz returns an NpzFile that keeps the file handle open
    # and reads members lazily; the `with` closes it once all three arrays
    # have been materialised (previously each archive was opened three times
    # and none of the handles was ever closed).
    with np.load(path) as archive:
        return (
            archive[prefix],
            archive[prefix + "_weights"],
            archive[prefix + "_spec"],
        )


def _build_callbacks(patience=20):
    """Create the checkpoint, CSV-logger and early-stopping Keras callbacks.

    The checkpoint and early-stopping callbacks both track ``val_loss``
    (lower is better). Output paths are relative to the current directory.

    Args:
        patience: epochs without ``val_loss`` improvement before stopping.

    Returns:
        list of callbacks, in the order ``[checkpoint, csv_logger, early_stopping]``.
    """
    checkpoint = ModelCheckpoint(
        "./saved_models/model-{epoch:03d}.ckpt",
        monitor="val_loss",
        verbose=2,
        save_freq="epoch",
        save_best_only=True,
        save_weights_only=False,
        mode="min",
    )
    csv_logger = CSVLogger("trainingCSV.csv", separator=",", append=False)
    early_stopping = EarlyStopping(
        monitor="val_loss",
        min_delta=0,
        patience=patience,
        verbose=1,
        restore_best_weights=True,
    )
    return [checkpoint, csv_logger, early_stopping]


def train(args, epochs=200, batch_size=50000, patience=20):
    """Prepare the MCa/MCb samples and fit the DSNNr model on them.

    Side effects: creates ``args.global_name`` if needed and changes the
    process working directory into it (the change is not undone on return).
    Inside that directory it writes ``config.txt`` (``str(args)``),
    ``trainingCSV.csv`` and model checkpoints under ``saved_models/``.

    Args:
        args: parsed options from ``DSNNr.handle_args()``. Reads
            ``global_name``, ``MCa``, ``MCb``, ``isTraining`` and ``features``
            (a comma-separated list), sets ``MCa_path``/``MCb_path``, and
            forwards the whole namespace to the DSNNr helpers.
        epochs: maximum number of training epochs.
        batch_size: mini-batch size passed to ``model.fit``.
        patience: EarlyStopping patience in epochs.

    Returns:
        The Keras ``History`` object returned by ``model.fit``.
    """
    # Work inside the run directory. NOTE: os.chdir affects the whole
    # process, so every relative path below resolves against global_name.
    os.makedirs(args.global_name, exist_ok=True)
    os.chdir(args.global_name)
    with open("config.txt", "w") as f:
        f.write(str(args))

    # Aliases presumably consumed by the DSNNr helpers — TODO confirm.
    args.MCa_path = args.MCa
    args.MCb_path = args.MCb

    if args.isTraining:
        # NOTE(review): despite the flag's name, this branch reuses cached
        # .npz archives instead of calling DSNNr.get_data; confirm intent.
        # NOTE(review): the MCb archive is named "MadgaphMCa..." (contains
        # "MCa"); confirm it matches the file the preprocessing step writes.
        MCa, MCa_weights, MCa_spec = _load_npz_sample("SherpaMCa.hadd.npz", "MCa")
        MCb, MCb_weights, MCb_spec = _load_npz_sample("MadgaphMCa.hadd.npz", "MCb")
    else:
        (MCa, MCb, MCa_spec, MCb_spec, MCa_weights, MCb_weights,
         _max_obj_count) = DSNNr.get_data(args)

    (
        X_train,
        X_test,
        Y_train,
        Y_test,
        train_weights,
        test_weights,
        _S_train,  # S_* outputs are not used by the fit below
        _S_test,
        class_weights,
    ) = DSNNr.handle_data(args, MCa, MCb, MCa_weights, MCb_weights, MCa_spec, MCb_spec)

    n_features = len(args.features.split(","))
    model = DSNNr.DS_model(n_features)
    callbacks = _build_callbacks(patience)

    # -----------
    # Train model
    # -----------
    # perf_counter is monotonic, so the duration is immune to clock changes.
    start_train = time.perf_counter()
    history = model.fit(
        X_train,
        Y_train,
        epochs=epochs,
        batch_size=batch_size,
        # Third tuple element is the per-sample validation weight.
        validation_data=(X_test, Y_test, test_weights),
        class_weight=class_weights,
        sample_weight=train_weights,
        verbose=1,
        callbacks=callbacks,
    )
    end_train = time.perf_counter()
    print("Time consumed in training: ", end_train - start_train)
    return history
if __name__ == "__main__":
    # Script entry point: parse the command line with the project's parser,
    # then hand the resulting namespace to the training routine.
    cli_args = DSNNr.handle_args()
    train(cli_args)