-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
62 lines (45 loc) · 1.27 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#from Dataloader import NYU_Depth_V2
from matplotlib import pyplot as plt
import numpy as np
import torch
from time import time
from loss.loss_functions import l1_loss, smooth_loss
from preprocessing.data_transformations import get_split, denormalize
# Visual sanity check for the data pipeline: pull one batch from the train
# split and one from the test split, and display the first few RGB images
# alongside their depth maps.


def _to_nhwc(batch):
    """Convert a channels-first batch to a channels-last NumPy array.

    Moves axis 1 to the end, (N, C, H, W) -> (N, H, W, C). This is the same
    result as swapping axes 1<->2 and then 2<->3.
    Assumes ``batch`` exposes ``.numpy()`` (e.g. a CPU torch tensor) -- TODO
    confirm against ``getBatch``.
    """
    return np.transpose(batch.numpy(), (0, 2, 3, 1))


def _show_batch(dataset, n_show=5):
    """Fetch one batch from *dataset* and plot up to *n_show* image/depth pairs.

    Images are passed through ``denormalize`` before display.
    NOTE(review): presumably this undoes the normalization applied by
    ``get_split`` -- confirm. Each image and each depth map is shown in its
    own figure via ``plt.show()``.
    """
    dataset.initBatch()
    imgs, dpts = dataset.getBatch()
    # Denormalize *before* the layout change so the displayed pixels are in
    # display range. The train path previously skipped this (``img =`` typo).
    imgs = _to_nhwc(denormalize(imgs))
    dpts = _to_nhwc(dpts)
    # Clamp to the batch size so short batches don't raise IndexError.
    for i in range(min(n_show, len(imgs))):
        plt.figure()
        plt.imshow(imgs[i].astype('uint8'))
        plt.show()
        plt.figure()
        # Only channel 0 of the depth tensor is plotted.
        plt.imshow(dpts[i, :, :, 0])
        plt.show()


train_data, val_data, test_data = get_split(train=True)
print('Train')
_show_batch(train_data)
print('Test data')
_show_batch(test_data)