-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
55 lines (46 loc) · 1.76 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import argparse
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from collections import OrderedDict
import matplotlib.pyplot as plt
import numpy as np
import json
from PIL import Image
from torch.autograd import Variable
import torchvision.models as models
import torch
from torch import nn, optim
import futility
import futility
import fmodel
parser = argparse.ArgumentParser(
description = 'Parser for predict.py'
)
parser.add_argument('input', default='./flowers/test/1/image_06752.jpg', nargs='?', action="store", type = str)
parser.add_argument('--dir', action="store",dest="data_dir", default="./flowers/")
parser.add_argument('checkpoint', default='./checkpoint.pth', nargs='?', action="store", type = str)
parser.add_argument('--top_k', default=5, dest="top_k", action="store", type=int)
parser.add_argument('--category_names', dest="category_names", action="store", default='cat_to_name.json')
parser.add_argument('--gpu', default="gpu", action="store", dest="gpu")
args = parser.parse_args()
path_image = args.input
number_of_outputs = args.top_k
device = args.gpu
path = args.checkpoint
def main():
    """Load the checkpointed model, classify the input image and print the results.

    Uses the module-level settings parsed from the command line: ``path``
    (checkpoint file), ``path_image`` (image to classify),
    ``number_of_outputs`` (how many classes to report), ``device`` and
    ``args.category_names`` (JSON file mapping category ids to names).
    Prints one line per reported class followed by a completion message.
    """
    model = fmodel.load_checkpoint(path)

    # Honour --category_names rather than a hard-coded file name, so the
    # option given on the command line actually takes effect.
    with open(args.category_names, 'r', encoding='utf-8') as json_file:
        cat_to_name = json.load(json_file)

    probabilities = fmodel.predict(path_image, model, number_of_outputs, device)
    # NOTE(review): presumably predict returns (probabilities, class indices),
    # each with a leading batch dimension, and category ids are index + 1 --
    # confirm against fmodel.predict and the checkpoint's class mapping.
    labels = [cat_to_name[str(index + 1)] for index in np.array(probabilities[1][0])]
    probability = np.array(probabilities[0][0])

    # Report at most number_of_outputs rows; zip stops at the shorter sequence,
    # so receiving fewer results than requested cannot raise IndexError.
    limit = max(number_of_outputs, 0)
    for label, prob in zip(labels[:limit], probability):
        print("{} with a probability of {}".format(label, prob))
    print("Finished Predicting!")
# Run the prediction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()