-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
115 lines (83 loc) · 2.9 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# importing required libraries
import numpy as np
import matplotlib.pyplot as plt
import argparse
import json
import os
from workspace_utils import active_session
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from collections import OrderedDict
from PIL import Image
# Imports other functions
from input_args import predict_args
# Parse the command-line options once and expose each one as a module-level
# name; the functions below read these globals directly.
args = predict_args()
data_dir, save_dir = args.data_dir, args.save_dir  # NOTE(review): data_dir is not used below
json_file = args.category_names  # JSON file loaded as cat_to_name in predict_image()
arch, gpu = args.arch, args.gpu
image_path, topk = args.image_path, args.top_k
# Function to load a checkpoint and rebuild the trained model
def load_checkpoint(filepath):
    """Rebuild the trained model stored in the checkpoint at *filepath*.

    The backbone is built from ``torchvision.models`` using the module-level
    ``arch`` option. The classifier head, the ``class_to_idx`` mapping and the
    weights are then restored from the checkpoint.

    Args:
        filepath: Path to a file written by ``torch.save`` that contains the
            keys ``'classifier'``, ``'class_to_idx'`` and ``'state_dict'``.

    Returns:
        The restored model. Every backbone parameter is frozen
        (``requires_grad=False``).

    Raises:
        ValueError: If ``arch`` does not name a model constructor in
            ``torchvision.models``.
    """
    # map_location="cpu" lets a checkpoint saved from a GPU session load on a
    # CPU-only machine; predict() later moves the model to its chosen device.
    # NOTE(review): torch.load unpickles arbitrary objects (the classifier
    # appears to be stored as a whole module), so only load trusted files.
    # PyTorch >= 2.6 defaults to weights_only=True, which would reject such a
    # pickled module; pass weights_only=False there if loading fails.
    checkpoint = torch.load(filepath, map_location="cpu")

    constructor = getattr(models, arch, None) if isinstance(arch, str) else None
    if not callable(constructor):
        # Fail clearly here rather than with an AttributeError on None below.
        raise ValueError(f"Unknown torchvision model architecture: {arch!r}")
    model = constructor(pretrained=True)

    # Freeze the backbone parameters; this script only runs inference.
    for param in model.parameters():
        param.requires_grad = False

    model.classifier = checkpoint['classifier']
    model.class_to_idx = checkpoint['class_to_idx']
    model.load_state_dict(checkpoint['state_dict'])
    return model
# Rebuild the trained model once, at import time, from the checkpoint
# selected by the save_dir command-line option.
# NOTE(review): assumes save_dir names a directory holding "checkpoint.pth";
# a path to the checkpoint file itself is also accepted. An empty option falls
# back to the "save_directory" location used previously.
checkpoint_path = save_dir or "save_directory"
if not os.path.isfile(checkpoint_path):
    checkpoint_path = os.path.join(checkpoint_path, "checkpoint.pth")
model = load_checkpoint(checkpoint_path)
# Function for preprocessing an image for prediction
def process_image(image):
    """Load an image and convert it into a normalized tensor for the model.

    Args:
        image: A path or file object accepted by ``PIL.Image.open``, or an
            already-open ``PIL.Image.Image``.

    Returns:
        A float tensor of shape (3, 224, 224). The image is resized so that
        its shorter side is 256 px, center-cropped to 224x224, scaled to
        [0, 1] by ``ToTensor`` and normalized per channel. The mean/std
        constants are the usual ImageNet statistics; presumably they match the
        training transforms (confirm).
    """
    if isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        # The context manager closes the underlying file. convert() forces the
        # pixel data to be read (into a new image) before that happens.
        with Image.open(image) as opened:
            pil_image = opened.convert("RGB")
    # Converting to RGB guarantees exactly three channels, which Normalize
    # below requires; grayscale, palette or RGBA inputs would otherwise fail.

    image_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    return image_transforms(pil_image)
def predict(image_path, model, topk, use_gpu=None):
    """Predict the most likely classes of an image with a trained model.

    Args:
        image_path: The image to classify (anything ``process_image`` accepts).
        model: The trained network, e.g. as returned by ``load_checkpoint``.
        topk: How many of the most likely classes to return. It is capped at
            the number of model outputs.
        use_gpu: Whether to run on CUDA when it is available. When ``None``,
            the module-level ``gpu`` command-line option is used.

    Returns:
        ``(top_p, top_class)``: the k highest probabilities and their output
        indices. Both are tensors of shape (1, k), assuming the model returns
        (1, num_classes), and both live on the inference device. The indices
        are positions in the model output, not class labels; invert
        ``model.class_to_idx`` to map them to labels.
    """
    if use_gpu is None:
        # NOTE(review): assumes the gpu option is a boolean flag. Any truthy
        # value (e.g. a non-empty string) enables CUDA, so confirm against
        # predict_args.
        use_gpu = gpu
    # Use CUDA only when it was requested AND is actually available.
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

    model.to(device)
    model.eval()  # switch layers such as dropout to inference behavior

    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    img = process_image(image_path).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img)
    # NOTE(review): assumes the classifier ends in LogSoftmax, so exp() turns
    # its output into probabilities. Confirm against the training script.
    probs = torch.exp(output)

    # Cap k so that asking for more classes than the model has does not raise.
    k = min(topk, probs.shape[1])
    top_p, top_class = probs.topk(k, dim=1)
    return top_p, top_class
def predict_image():
    """Classify ``image_path`` and print the top-k probabilities and names.

    Uses the module-level command-line options (``image_path``, ``topk``,
    ``json_file``) and the ``model`` loaded at import time. ``json_file`` is
    read as a mapping from category id (string key) to display name.
    """
    with open(json_file, "r", encoding="utf-8") as file:
        cat_to_name = json.load(file)

    probs, classes = predict(image_path, model, topk)

    # topk() yields positions in the model output, not category ids.
    # model.class_to_idx (restored from the checkpoint; presumably
    # ImageFolder's label -> index mapping) is inverted to recover each
    # label, which is then looked up in cat_to_name.
    idx_to_class = {idx: label for label, idx in model.class_to_idx.items()}
    top_indices = classes.cpu().numpy().tolist()[0]
    image_labels = [cat_to_name[str(idx_to_class[i])] for i in top_indices]

    # .cpu() is required before NumPy conversion when inference ran on CUDA.
    print(probs.cpu().numpy()[0])
    print(image_labels)


if __name__ == "__main__":
    predict_image()