-
Notifications
You must be signed in to change notification settings - Fork 1
/
visualize.py
57 lines (44 loc) · 1.67 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from modules import *
from read_dataset import prepare_data, transform_data
# Run on the first GPU when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')

# Only the second loader returned by prepare_data is used here; the first
# (presumably the training loader) is discarded.
_, test_dataloader = prepare_data(batch_size=16)

ckp_path = "./wide_resnet.pt"

# Rebuild the architecture the checkpoint was saved from: a Wide ResNet-50-2
# whose final fully connected layer is replaced by a 256-output linear layer.
# ImageNet weights are NOT downloaded: load_state_dict below is strict by
# default, so it overwrites every parameter and buffer anyway.
net = models.wide_resnet50_2(pretrained=False)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 256)

# NOTE(review): criterion and optimizer are not used anywhere in this file;
# they are kept only so these module-level names remain available. The
# optimizer is created after the head replacement so it covers the real params.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0005, momentum=0.9)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
trained_model = torch.load(ckp_path, map_location=device)
net.load_state_dict(trained_model)
net = net.to(device)
# Put BatchNorm/Dropout layers into inference behaviour.
net.eval()

train_dataset, test_dataset = transform_data()
def visualize_model(net, num_images=4):
    """Plot ``net``'s predictions for the first ``num_images`` test images.

    Images are taken from the module-level ``test_dataloader`` and drawn on a
    2-row grid of subplots. Each subplot is titled with the predicted class
    name, looked up in ``test_dataset.labels``.

    Args:
        net: Classifier to run on each batch. It is expected to already be on
            ``device`` and in eval mode, as the module-level ``net`` is.
        num_images: Number of images to plot. Must be at least 1.

    Returns:
        None (the result of ``plt.show()``).

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    # Ceil division so an odd num_images still fits on the 2-row grid;
    # num_images // 2 columns would be one short (or zero for num_images=1).
    ncols = (num_images + 1) // 2
    plt.figure(figsize=(15, 10))
    images_so_far = 0

    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for inputs, _labels in test_dataloader:
            inputs = inputs.to(device)
            _, preds = torch.max(net(inputs), 1)
            preds = preds.cpu().numpy()
            # Move the batch back to the CPU once for plotting, not once per image.
            inputs = inputs.cpu()
            for j in range(inputs.size(0)):
                images_so_far += 1
                ax = plt.subplot(2, ncols, images_so_far)
                ax.axis('off')
                ax.set_title('prediction: {}'.format(test_dataset.labels[preds[j]]))
                # imshow is presumably provided by the `modules` star import;
                # permute converts CHW to the HWC layout image plotting expects.
                imshow(inputs[j].permute(1, 2, 0))
                if images_so_far == num_images:
                    return plt.show()

    # The loader ran out before num_images images: still show what was drawn.
    return plt.show()
if __name__ == "__main__":
    # In interactive mode, the plt.show() call inside visualize_model does not block.
    plt.ion()
    visualize_model(net)
    plt.ioff()
    # Block until the figure window is closed. Without this, the script exits
    # (and the window disappears) as soon as drawing has finished.
    plt.show()