-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathimage_utils.py
executable file
·49 lines (37 loc) · 1.39 KB
/
image_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import math
import cv2
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
# Target spatial size (rows, columns) used by scale_image and center_crop.
HEIGHT, WIDTH = 228, 304
def scale_image(img, scale=None, interpolation=None):
    """Resize an image uniformly by ``scale``.

    If ``scale`` is None, the image is scaled by the factor that makes it at
    least HEIGHT rows by WIDTH columns (it covers the target size while
    keeping the aspect ratio). An exact HEIGHT x WIDTH window can then be cut
    out with ``center_crop``.

    Args:
        img: image array indexed as (rows, cols[, channels]).
        scale: uniform scale factor; computed from HEIGHT/WIDTH when None.
        interpolation: OpenCV interpolation flag. None (the default) means
            ``cv2.INTER_NEAREST``, the original behaviour, which does not
            blend neighbouring pixel values. Pass e.g. ``cv2.INTER_AREA`` or
            ``cv2.INTER_LINEAR`` for smoother results on photographs.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    if interpolation is None:
        # Resolved here rather than in the signature so defining the function
        # does not depend on cv2 attributes.
        interpolation = cv2.INTER_NEAREST
    if scale is None:
        # max() takes the larger of the two ratios, so both sides reach at
        # least their target size after scaling.
        scale = max(WIDTH / img.shape[1], HEIGHT / img.shape[0])
    # cv2.resize expects dsize as (width, height); ceil keeps both sides
    # from rounding down below the target.
    new_size = (math.ceil(img.shape[1] * scale), math.ceil(img.shape[0] * scale))
    return cv2.resize(img, new_size, interpolation=interpolation)
def center_crop(img, height=None, width=None):
    """Center crop an image to ``height`` x ``width`` (default HEIGHT x WIDTH).

    When the excess in a dimension is odd, the extra row/column is removed
    from the bottom/right side.

    Args:
        img: array indexed as (rows, cols[, ...]); trailing axes such as
            channels are kept unchanged.
        height: number of output rows; defaults to the module constant HEIGHT.
        width: number of output columns; defaults to the module constant WIDTH.

    Returns:
        A slice of ``img`` (a view, for NumPy arrays) with ``height`` rows
        and ``width`` columns.

    Raises:
        ValueError: if ``img`` is smaller than the crop size in either
            dimension. Previously the negative offsets silently produced a
            wrongly sized (possibly empty) crop.
    """
    if height is None:
        height = HEIGHT
    if width is None:
        width = WIDTH
    rows, cols = img.shape[0], img.shape[1]
    if rows < height or cols < width:
        raise ValueError(
            f"cannot center-crop a {rows}x{cols} image to {height}x{width}; "
            "resize it first (e.g. with scale_image)"
        )
    top = (rows - height) // 2
    left = (cols - width) // 2
    return img[top:top + height, left:left + width]
def img_transform(img):
    """Convert an image to a tensor and normalize it per channel.

    ``transforms.ToTensor`` converts the input to a float tensor in
    channel-first layout. ``transforms.Normalize`` then subtracts the
    per-channel mean and divides by the per-channel std. The constants match
    the commonly used ImageNet statistics from torchvision's pretrained
    models.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    return pipeline(img)
def save_img_and_pred(img, depth, save_path):
    """Plot an image and its prediction side by side, then save the plot.

    Args:
        img: image accepted by ``plt.imshow``; drawn in the left panel.
        depth: 3-D prediction array. It is transposed with axes (1, 2, 0), so
            axis 0 is presumably the channel axis (C, H, W); confirm with
            callers. Only channel 0 is drawn, in the right panel.
        save_path: destination passed to ``plt.savefig``.
    """
    fig = plt.figure()
    try:
        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.subplot(1, 2, 2)
        # Move axis 0 last, (C, H, W) -> (H, W, C), then show the first channel.
        pred = np.transpose(depth, (1, 2, 0))
        plt.imshow(pred[:, :, 0])
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # repeated calls leak memory and trigger matplotlib's
        # "More than 20 figures" warning.
        plt.close(fig)