-
Notifications
You must be signed in to change notification settings - Fork 41
/
visualize_attention_map.py
69 lines (59 loc) · 2.4 KB
/
visualize_attention_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# encoding: utf-8
"""
@author: rentianhe
@contact: [email protected]
"""
import numpy as np
import cv2
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
def visualize_grid_attention(img_path, save_path, attention_mask, ratio=0.5, save_image=True, save_original_image=True, quality=100):
    """Overlay a 2-D attention mask on an image and optionally save the result.

    The image is loaded from ``img_path`` and scaled by ``ratio``. The
    attention mask is resized to the scaled image, normalised so that its
    maximum becomes 1, and multiplied into every colour channel, so regions
    with low attention are darkened.

    Args:
        img_path: where to load the image from.
        save_path: directory that receives the output files; it is created
            (including missing parents) when something is saved.
        attention_mask: the 2-D attention mask as a numpy array that
            ``cv2.resize`` accepts. It is resized to the scaled image, so it
            does not have to match the image resolution.
        ratio: scaling factor applied to both the width and the height.
        save_image: when true, write ``<name>_with_attention.jpg`` into
            ``save_path``.
        save_original_image: when true, write the scaled image as
            ``<name>_original.jpg`` into ``save_path``.
        quality: used as the DPI of the attention figure and as the JPEG
            quality of the saved original image.

    Raises:
        ValueError: if saving is requested but ``save_path`` is empty or None.
    """
    if (save_image or save_original_image) and not save_path:
        raise ValueError("you need to set where to store the picture")

    print("load image from: " + img_path)
    # Convert to RGB so the (h, w, 1) mask broadcasts over exactly three
    # channels and the result can always be written as JPEG; the context
    # manager releases the file handle once the pixels are loaded.
    with Image.open(img_path) as source:
        img = source.convert("RGB")

    # PIL reports size as (width, height); matplotlib's figsize, PIL's resize
    # and cv2.resize all expect (width, height) as well.
    src_w, src_h = img.size
    # Strip directory and extension with os.path so names containing several
    # dots and platform-specific separators are handled.
    base_name = os.path.splitext(os.path.basename(img_path))[0]

    fig, _ = plt.subplots(nrows=1, ncols=1, figsize=(0.02 * src_w, 0.02 * src_h))
    try:
        # scale the image
        out_w, out_h = int(src_w * ratio), int(src_h * ratio)
        img = img.resize((out_w, out_h))

        # normalize the attention map
        mask = cv2.resize(attention_mask, (out_w, out_h))
        peak = mask.max()
        if peak != 0:
            normed_mask = mask / peak
        else:
            # An all-zero mask would otherwise divide by zero and yield NaN.
            normed_mask = np.zeros(mask.shape, dtype=np.float64)
        normed_mask = normed_mask[..., np.newaxis]

        # put the attention map on the original image
        result = (np.asarray(img) * normed_mask).astype("uint8")
        plt.imshow(result, alpha=1)
        plt.axis('off')

        # save image
        if save_image:
            # build save path
            os.makedirs(save_path, exist_ok=True)
            img_with_attention_save_path = os.path.join(save_path, base_name + "_with_attention.jpg")

            # pre-process before saving: drop every margin around the axes
            print("save image to: " + save_path)
            plt.axis('off')
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.savefig(img_with_attention_save_path, dpi=quality)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    # save original image
    if save_original_image:
        # build save path
        os.makedirs(save_path, exist_ok=True)

        print("save original image at the same time")
        original_image_save_path = os.path.join(save_path, base_name + "_original.jpg")
        img.save(original_image_save_path, quality=quality)