forked from NeuroSYS-pl/objects_counting_dmap
-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer.py
232 lines (187 loc) · 6.04 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
"""This script applies a chosen model to a given image.
One needs to choose a network architecture and provide the corresponding
state dictionary.
Example:
$ python infer.py -n UNet -c mall_UNet.pth -i seq_000001.jpg
The script can also visualize the results by drawing the resulting
density map next to the input image.
Example:
$ python infer.py -n UNet -c mall_UNet.pth -i seq_000001.jpg --visualize
"""
import os
import click
import torch
import numpy as np
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from matplotlib.figure import figaspect
from PIL import Image
from skimage.feature import peak_local_max
from models import UNet, FCRN_A
from looper import calculate_classifications, find_predicted_dots
@click.command()
@click.option(
    "-i",
    "--infer_path",
    type=click.Path(),
    required=True,
    help="A path to an input image to infer.",
)
@click.option(
    "-n",
    "--network_architecture",
    type=click.Choice(["UNet", "FCRN_A"]),
    required=True,
    help="Model architecture.",
)
@click.option(
    "-c",
    "--checkpoint",
    type=click.File("r"),
    required=True,
    help="A path to a checkpoint with weights.",
)
@click.option(
    "--unet_filters",
    default=64,
    help="Number of filters for U-Net convolutional layers.",
)
@click.option(
    "--convolutions", default=2, help="Number of layers in a convolutional block."
)
@click.option(
    "--one_channel",
    is_flag=True,
    help="Turn this on for one channel images (required for ucsd).",
)
@click.option(
    "--pad", is_flag=True, help="Turn on padding for input image (required for ucsd)."
)
@click.option(
    "-v",
    "--valid_path",
    type=click.File("r"),
    help="A path to an answer image containing true keypoints.",
)
@click.option("--visualize", is_flag=True, help="Visualize predicted density map.")
@click.option("--save", type=click.Path(exists=False), help="Save visualized plots to path.")
def infer(
    infer_path: str,
    valid_path: str,
    network_architecture: str,
    checkpoint: str,
    unet_filters: int,
    convolutions: int,
    one_channel: bool,
    pad: bool,
    visualize: bool,
    save: str,
):
    """Run inference on one image, or on every file inside a directory.

    For a directory the per-image counts are summed and the total is printed.
    """
    # NOTE: click hands `checkpoint` and `valid_path` over as opened file
    # objects (click.File), not as strings; `_infer` reads their `.name`.
    if not os.path.isdir(infer_path):
        # single image: `save` (if given) is the output file for the plot
        _infer(
            infer_path,
            valid_path,
            network_architecture,
            checkpoint,
            unet_filters,
            convolutions,
            one_channel,
            pad,
            visualize,
            save,
        )
        return

    # directory: only regular files are inferred (sub-directories cannot be
    # opened as images); sorted so the processing order is deterministic
    files = sorted(
        f for f in os.listdir(infer_path)
        if os.path.isfile(os.path.join(infer_path, f))
    )

    # here `save` is treated as an output directory; it is optional and may
    # already exist from a previous run
    if save is not None:
        os.makedirs(save, exist_ok=True)

    total = 0
    for f in files:
        total += _infer(
            os.path.join(infer_path, f),
            valid_path,
            network_architecture,
            checkpoint,
            unet_filters,
            convolutions,
            one_channel,
            pad,
            visualize,
            # one plot per input image, named after the input file
            os.path.join(save, f) if save is not None else None,
        )
    print(f"Total objects found in {len(files)} images: {total}")
def _infer(
    infer_path: str,
    valid_path: str,
    network_architecture: str,
    checkpoint: str,
    unet_filters: int,
    convolutions: int,
    one_channel: bool,
    pad: bool,
    visualize: bool,
    save: str,
):
    """Run inference for a single image.

    Args:
        infer_path: path of the image to run the network on.
        valid_path: optional answer image with true keypoints; either a
            path or an opened file object exposing ``.name`` (as produced
            by ``click.File``). Non-zero pixels are counted as objects.
        network_architecture: ``"UNet"`` or ``"FCRN_A"``.
        checkpoint: state dictionary to load; either a path or an opened
            file object exposing ``.name``.
        unet_filters: forwarded to the model constructor as ``filters``.
        convolutions: forwarded to the model constructor as ``N``.
        one_channel: build the model for 1 input channel instead of 3.
        pad: zero-pad the image by one pixel on every side before inference.
        visualize: draw the result with ``_visualize``.
        save: optional output path for the visualization.

    Returns:
        The predicted number of objects (sum of the density map / 100).
    """
    # use GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # only UCSD dataset provides greyscale images instead of RGB
    input_channels = 1 if one_channel else 3

    # initialize a model based on chosen network_architecture
    network = {"UNet": UNet, "FCRN_A": FCRN_A}[network_architecture](
        input_filters=input_channels, filters=unet_filters, N=convolutions
    ).to(device)

    # load provided state dictionary
    # note: by default train.py saves the model in data parallel mode
    network = torch.nn.DataParallel(network)
    # accept both an opened file object (click.File) and a plain path
    checkpoint_path = getattr(checkpoint, "name", checkpoint)
    network.load_state_dict(torch.load(checkpoint_path, map_location=device))
    network.eval()

    img = Image.open(infer_path)

    # padding was applied for ucsd images to allow down and upsampling
    if pad:
        img = Image.fromarray(np.pad(img, 1, "constant", constant_values=0))

    # network's output represents a density map;
    # no gradients are needed for inference
    with torch.no_grad():
        density_map = network(TF.to_tensor(img).unsqueeze_(0))

    # note: density maps were normalized to 100 * no. of objects
    n_objects = torch.sum(density_map).item() / 100
    print(f"The number of objects found: {n_objects}")

    answer_key = None
    # os.path.isfile() rejects file objects with a TypeError, so resolve the
    # underlying path first
    key_path = getattr(valid_path, "name", valid_path)
    if key_path is not None and os.path.isfile(key_path):
        with Image.open(key_path) as key_img:
            answer_key = np.array(key_img)
        print(f"The true number of objects: {np.count_nonzero(answer_key)}")

    if visualize:
        _visualize(img, density_map.squeeze().cpu().detach().numpy(), n_objects, answer_key, save)

    return n_objects
def _visualize(img, dmap, n_objects, key, save=None):
    """Draw the density map and the detected points in a three-panel figure.

    Panels, left to right:
        * "Raw" - the density map itself.
        * "Predicted vs true" - predicted dots (red crosses) and, when
          ``key`` is given, true keypoints (green squares) on a blank canvas.
        * "TPs, FPs and FNs" - the input image overlaid with classified
          detections (green = TP, red = FP, yellow = FN) when ``key`` is
          given, otherwise with the raw predictions in red.

    Args:
        img: the input image, shown in the third panel.
        dmap: 2-D density map array.
        n_objects: predicted object count (currently unused here).
        key: answer array whose non-zero entries mark true keypoints,
            or ``None`` when no ground truth is available.
        save: optional path the figure is written to before being shown.
    """
    fig, axes = plt.subplots(1, 3)

    # turn off axis ticks
    for ax in axes:
        ax.axis("off")

    # display raw density map
    axes[0].imshow(dmap, cmap="hot")
    axes[0].set_title("Raw")

    # display intermediate step on an empty canvas of the same size
    axes[1].imshow(np.zeros(dmap.shape), cmap="gray")
    axes[1].set_title("Predicted vs true")

    # display original image
    axes[2].imshow(img)
    axes[2].set_title("TPs, FPs and FNs")

    # np.nonzero returns (rows, cols); scatter expects x=cols, y=rows
    dots = find_predicted_dots(dmap)
    pred = np.nonzero(dots)
    axes[1].scatter(x=pred[1], y=pred[0], c="#ff0000", s=10, marker="x")

    if key is not None:
        # overlay true keypoints
        edge_color = "#00ff00"
        true = np.nonzero(key)
        axes[1].scatter(x=true[1], y=true[0], c="none", s=10, marker="s", edgecolors=edge_color)

        TP, FP, FN = calculate_classifications(key, dots)
        tps = np.nonzero(TP)
        fps = np.nonzero(FP)
        fns = np.nonzero(FN)
        axes[2].scatter(x=tps[1], y=tps[0], c="#00ff00", s=20, marker="x")
        axes[2].scatter(x=fps[1], y=fps[0], c="#ff0000", s=20, marker="x")
        axes[2].scatter(x=fns[1], y=fns[0], c="#ffff00", s=20, marker="x")
    else:
        # no ground truth: just mark the predictions on the image
        axes[2].scatter(x=pred[1], y=pred[0], c="#ff0000", s=20, marker="x")

    if save is not None:
        fig.savefig(save, bbox_inches="tight")

    plt.show()
    # release the figure so they do not pile up when many images are processed
    plt.close(fig)
# Script entry point: invoke the click command only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    infer()