evaluate.py
import os
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from plotUtils import *  # Plotter, plot_calibration_curve, plot_weights
from models import *
from data import *
#################################################
# set random seeds
def fix_randomness(seed: int, deterministic: bool = False) -> None:
    pl.seed_everything(seed, workers=True)
    if deterministic:
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":16:8"
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)

fix_randomness(42, True)
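# Note on determinism: torch.use_deterministic_algorithms(True) raises a RuntimeError
# when an op without a deterministic implementation is hit, and the
# CUBLAS_WORKSPACE_CONFIG setting is what makes cuBLAS matmuls reproducible on CUDA >= 10.2.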
def get_flat_array(result_tensor, idx):
    # result_tensor is a list of length ≃ instances/BATCH_SIZE;
    # each element is a tuple (probs, weights),
    # and each tuple element is a tensor of shape (BATCH_SIZE, 1).
    # Get the requested element (probs/weights) as a flat np.array on the CPU.
    elements = map(lambda x: x[idx].cpu().detach().numpy(), result_tensor)
    # first flatten each array to (BATCH_SIZE,) and then np.concatenate(list)
    elements = map(lambda x: x.reshape(-1), elements)
    elements = np.concatenate(list(elements))
    return elements
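# Example (shapes are illustrative): with BATCH_SIZE = 4 and two predicted batches,
# result_tensor = [(p0, w0), (p1, w1)] with each tensor of shape (4, 1), so
# get_flat_array(result_tensor, 1) returns the weights as an np.array of shape (8,).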
#################################################
# Argument parsing
parser = argparse.ArgumentParser(usage="usage: %(prog)s [opts]")
parser.add_argument('-m', '--model', action='store', type=str, dest='model', required=True, help='The model used for evaluation.')
parser.add_argument('-n', '--batchNorm', action='store_true', default=False, help='Do batch normalization.')
parser.add_argument('-t', '--transform', choices=['None', 'NormPerImg', 'NormGlob', 'LogScale'], default='None', help='Type of transform to perform on the input data (e.g., normalizing everything to be in the range [0,1]).')
parser.add_argument('--stride', type=int, default=3, help='Stride of filter')
parser.add_argument('-b', '--batchSize', type=int, default=256, help='Batch size') #128 is better for atlasgpu
parser.add_argument('-l', '--logVersion', type=str, default='version_0', help='Version of the log to use for plotting')
parser.add_argument('-d', '--dataPath', type=str, default='/data/whopkins/ILDCaloSim/e-_Jun3/test/', help='Path to data files.')
parser.add_argument('-a', '--alt_key', type=str, default='RC10', help='Key for alternative range cut, e.g., RC10.')
opts = parser.parse_args()
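# Example invocation (model path and flag values are illustrative):
#   python evaluate.py -m models/conv3d_transNormGlob.pt --stride 3 -b 256 -a RC10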
model_path = opts.model
#################################################
# configuration
BATCH_SIZE = opts.batchSize
NUM_WORKERS = 2
# strip the directory and the .pt extension; rstrip('pt') strips *characters*
# and would also eat trailing 'p'/'t' letters from the model name itself
MODELNAME = os.path.splitext(os.path.basename(model_path))[0]
batchNormStr = ''
if opts.batchNorm:
    batchNormStr = '_batchNorm'
suffix = f'_stride{opts.stride}'+batchNormStr
transform = None
GLOBAL_FEATURES = []
for i in MODELNAME.split('_'):
    # after splitting on '_', a global-feature token appears as e.g. 'Genergy',
    # so test the 'G' prefix (the original '_G' membership check could never match)
    if i.startswith('G'):
        GLOBAL_FEATURES.append(i[1:])
    if i.startswith('trans'):
        # slice off the 'trans' prefix; lstrip('trans') strips characters, not a
        # prefix, and can over-strip transform names starting with t/r/a/n/s
        transformStr = i[len('trans'):]
        if transformStr == 'NormGlob':
            global_max = get_global_max(opts.dataPath, opts.alt_key)
            transform = globals()[transformStr](global_max)
        else:
            transform = globals()[transformStr]()
        suffix += '_'+i
if len(GLOBAL_FEATURES) > 0:
    suffix += '_'+'_'.join(GLOBAL_FEATURES)
suffix += '_'+opts.alt_key
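# Illustrative decoding: a hypothetical MODELNAME 'conv3d_Genergy_transNormGlob'
# yields GLOBAL_FEATURES = ['energy'], transform = NormGlob(global_max), and
# suffix = '_stride3_transNormGlob_energy_RC10' (with the defaults above).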
#################################################
# load test dataset
dataset = get_HDF5_dataset(opts.dataPath+'showers-10kE10GeV-'+opts.alt_key+'-1.hdf5')
dataset_t = get_tensor_dataset(dataset, GLOBAL_FEATURES, transform=transform)
#dataset = get_HDF5_dataset('/data/ekourlitis/ILDCaloSim/e-_large/test/showers-10kE10GeV-RC10-95.hdf5')
#dataset_t = get_tensor_dataset(dataset, GLOBAL_FEATURES)
# load nominal dataset (just for plotting)
nom_dataset = get_HDF5_dataset(opts.dataPath+'showers-10kE10GeV-RC01-1.hdf5')
# get the labels
# each dataset item ends with the label (with global features the items are
# (image, features, label)), so index with -1 rather than 1
labels = np.array(list(map(lambda x: x[-1].numpy(), dataset_t))).reshape(-1)
# number of instances/examples
instances = len(dataset_t)
print("Number of instances to predict: %i" % instances)
test_loader = DataLoader(dataset_t,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=NUM_WORKERS)
#################################################
'''
# inspect one batch from the test loader
dataiter = iter(test_loader)
# images, labels = next(dataiter)
images, features, labels = next(dataiter)
import pdb; pdb.set_trace()
'''
#################################################
# determine input shapes
inputShape = next(iter(test_loader))[0].numpy().shape[1:]
print("Shape of input:",inputShape)
if len(GLOBAL_FEATURES):
num_features = next(iter(test_loader))[1].numpy().shape[1:][0]
print("Number of high-level (global) features:", num_features)
# init model
if len(GLOBAL_FEATURES):
    model = Conv3DModelGF(inputShape,
                          num_features,
                          use_batchnorm=opts.batchNorm,
                          use_dropout=True,
                          stride=opts.stride,
                          hidden_layers_in_out=[(512,512), (512,512)])
else:
    model = Conv3DModel(inputShape,
                        use_batchnorm=opts.batchNorm,
                        use_dropout=True,
                        stride=opts.stride,
                        hidden_layers_in_out=[(512,512), (512,512)])
model.load_state_dict(torch.load(model_path))
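# This assumes the checkpoint was saved on a device visible here; on a CPU-only
# node, torch.load(model_path, map_location='cpu') would be needed instead.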
# init a trainer
# use GPU1
trainer = pl.Trainer(gpus=[1])
# accelerator='dp')
# inference
result_tensor = trainer.predict(model,
                                test_loader,
                                return_predictions=True)
probs = get_flat_array(result_tensor, 0)
weights = get_flat_array(result_tensor, 1)
# clip on maxWeight
#maxWeight = 500
#weights[weights > maxWeight] = 0
# how many zeros?
zero_counter = np.count_nonzero(weights==0)
print("Zero weights fraction: %0.3f%% " % ( (zero_counter/len(weights))*100 ))
#################################################
# Plotting
plots = Plotter(nom_dataset, dataset, weights)
# plot_weights(weights, suffix=suffix)
plots.plot_event_observables(suffix=suffix)
# plot_calibration_curve(labels, probs)
# csvLoggerPath = "logs/"+MODELNAME+'_csv/'+opts.logVersion+'/metrics.csv'
# print(csvLoggerPath)
# plot_metrics(csvLoggerPath, suffix=suffix)