-
Notifications
You must be signed in to change notification settings - Fork 2
/
show.py
85 lines (63 loc) · 2.24 KB
/
show.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import test
import settings as S
import flow
import train
import infer
import matplotlib.pyplot as plt
import numpy as np
import plac
import time
import test
def display_images(images, names=None):
    """Plot every image in its own cell of a square subplot grid, then show the figure.

    images -- sequence of array-likes; each one is ``squeeze()``d (size-1 axes
              dropped) before being handed to ``plt.imshow``.
    names  -- optional sequence of subplot titles, parallel to ``images``.
    """
    figure = plt.figure()
    # Smallest n such that an n x n grid holds every image.
    side = int(np.ceil(np.sqrt(len(images))))
    for position, picture in enumerate(images, start=1):
        figure.add_subplot(side, side, position)
        plt.imshow(picture.squeeze())
        if names is not None:
            plt.title(names[position - 1])
    plt.show()
def predict_and_show(df, argmax=True):
    """Load the best saved model and display its prediction for every batch in ``df``.

    For each ``(imgs, mask)`` pair yielded by ``df`` a figure is shown with:
    the first three channels of ``imgs`` ("Pre"), the remaining channels
    ("Post"), optionally the argmax-converted prediction ("Argmax"), channels
    0 and 1 of the converted prediction ("Pred1", "Pred2") and the converted
    ground-truth mask ("Ground Truth").

    df     -- iterable of ``(imgs, mask)`` pairs (``main`` passes a ``flow.Dataflow``).
    argmax -- when True (default) include the "Argmax" panel; when False it is
              omitted and ``infer.convert_prediction(..., argmax=True)`` is not called.

    Any exception raised while drawing is re-raised after the shape of every
    panel has been printed.
    """
    model = train.build_model()
    model = train.load_weights(model, S.MODELSTRING_BEST)
    for imgs, mask in df:
        pre = imgs[..., :3]
        post = imgs[..., 3:]
        raw_pred = model.predict(imgs)
        mask = infer.convert_prediction(mask)
        # Same call order as before: argmax conversion first, then the raw conversion.
        maxed = infer.convert_prediction(raw_pred, argmax=True) if argmax else None
        # With argmax=False convert_prediction returns a pair; only the first item is used.
        pred, _ = infer.convert_prediction(raw_pred, argmax=False)

        panels = [("Pre", pre), ("Post", post)]
        if argmax:
            panels.append(("Argmax", maxed))
        panels += [
            ("Pred1", pred[..., 0]),
            ("Pred2", pred[..., 1]),
            ("Ground Truth", mask),
        ]
        names = [name for name, _ in panels]
        images = [img for _, img in panels]

        try:
            display_images(images, names)
        except Exception:
            # Shape mismatches are the usual failure here; dump them before re-raising.
            for name, img in panels:
                print(name, img.shape)
            raise
def predict_and_show_no_argmax(df):
    """Call ``predict_and_show`` on ``df`` with ``argmax=False`` and return its result."""
    result = predict_and_show(df, argmax=False)
    return result
def show(df):
    """Display the pre image, post image and converted mask for every batch in ``df``.

    ``imgs[..., :3]`` is shown as the pre image and ``imgs[..., 3:]`` as the
    post image. Their titles are the ``img_name`` of ``df.samples[i][0]`` and
    ``df.samples[i][1]``, where ``i`` is the batch index.

    NOTE(review): pairing titles by batch index assumes one sample per batch
    and that ``df`` yields batches in ``df.samples`` order. ``main`` builds the
    Dataflow with ``shuffle=True``, so confirm the titles match the images.
    """
    for i, (imgs, masks) in enumerate(df):
        pre = imgs[..., :3]
        post = imgs[..., 3:]
        mask = infer.convert_prediction(masks)
        sample = df.samples[i]
        titles = [sample[0].img_name, sample[1].img_name, "Mask"]
        display_images([pre, post, mask], titles)
def main(predict: ("Do prediction", "flag", "p"),
         image: ("Show this specific image", "option", "i")=""):
    """Display validation samples, optionally together with model predictions.

    Arguments are parsed by ``plac`` from the annotations:

    predict -- flag ``-p``: when set, call ``predict_and_show``; otherwise ``show``.
    image   -- option ``-i``: a substring of a sample's pre or post ``img_name``.
               When given, only the first matching sample is displayed. If no
               sample matches, a notice is printed and all samples are shown.
    """
    df = flow.Dataflow(files=flow.get_validation_files(), shuffle=True, batch_size=1,
                       buildings_only=True, return_stacked=True, transform=0.5,
                       return_average=False, return_postmask=True)
    if image != "":
        # Stop at the first match. df.samples is replaced by a one-element list,
        # so searching on by the old length would index past its end.
        match = next(
            (sample for sample in df.samples
             if image in sample[0].img_name or image in sample[1].img_name),
            None,
        )
        if match is None:
            print("No validation sample matches {!r}; showing all samples.".format(image))
        else:
            df.samples = [match]
    # Display exactly once: predictions (which include pre/post/ground truth) or raw samples.
    if predict:
        predict_and_show(df)
    else:
        show(df)
if __name__ == '__main__':
    # Set the global batch size to 1 before plac parses the CLI and invokes main.
    S.BATCH_SIZE = 1
    plac.call(main)