-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathsave_multi_object_results.py
52 lines (39 loc) · 1.59 KB
/
save_multi_object_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import pickle
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from multi_object.utils import get_multi_object_net, get_multi_object_test_set
from utils.utils import Settings
# Run the multi-object network over the whole test set and pickle the
# histories, ground-truth futures, masks, lanes and predictions into
# ./results/<args.load_name>.pickle for offline analysis.
args = Settings()
net = get_multi_object_net()
testSet = get_multi_object_test_set()
testDataloader = DataLoader(testSet, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers,
                            collate_fn=testSet.collate_fn)

# train_flag is the model's own switch (its effect is defined in the model,
# not shown here); eval() additionally puts any dropout/batch-norm layers
# into inference mode.
net.train_flag = False
net.eval()

hist_test = []
mask_test = []
fut_test = []
pred_test = []
lines_test = []
path_list = testSet.dataset['path']

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for hist, fut, mask_hist, mask_fut, lines, mask_lines in testDataloader:
        # Prediction horizon = size of dim 0 of the ground-truth future
        # (presumably the time axis -- confirm against the collate_fn).
        len_pred = fut.shape[0]
        hist = hist.to(args.device)
        mask_hist = mask_hist.to(args.device)

        pred_fut = net(hist, mask_hist, len_pred)
        # Keep the last len_pred entries along dim 0 and insert a singleton
        # axis at position 3, matching the layout the original script saved.
        pred_fut = pred_fut[-len_pred:, :, :, None, :].detach().cpu().numpy()

        lines_test.append(lines.cpu().numpy())
        hist_test.append(hist.cpu().numpy())
        mask_test.append(mask_fut.cpu().numpy())
        fut_test.append(fut.cpu().numpy())
        pred_test.append(pred_fut)

# Create the output directory if needed so a full evaluation run is not lost
# to a FileNotFoundError at the very end.
results_dir = Path('./results')
results_dir.mkdir(parents=True, exist_ok=True)
with open(results_dir / (args.load_name + '.pickle'), 'wb') as handle:
    pickle.dump({'hist': hist_test, 'mask': mask_test, 'fut': fut_test,
                 'pred': pred_test, 'path': path_list, 'lines': lines_test},
                handle, protocol=pickle.HIGHEST_PROTOCOL)