-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutility_fun.py
executable file
·90 lines (83 loc) · 2.55 KB
/
utility_fun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -*- coding: utf-8 -*-
"""
Created on Mon May 24 17:19:16 2021
@author: pankaj.mishra
"""
from scipy.ndimage import gaussian_filter, median_filter
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage.measure import label
import os
def Normalise(score_map):
    """
    Min-max scale ``score_map`` so its values span [0, 1].

    Parameters
    ----------
    score_map : ndarray or tensor
        Any object supporting ``.min()``, ``.max()``, subtraction and
        true division (NumPy arrays and torch tensors both qualify).

    Returns
    -------
    scores : same kind as ``score_map``
        ``(score_map - min) / (max - min)``. A constant map (max == min)
        yields all zeros instead of the NaNs that 0/0 would produce.
    """
    min_score = score_map.min()
    value_range = score_map.max() - min_score
    if value_range == 0:
        # Constant map: dividing by 1 keeps the true-division dtype
        # behaviour but gives zeros rather than 0/0 = NaN everywhere.
        value_range = 1
    scores = (score_map - min_score) / value_range
    return scores
def Mean_var(score_map):
    """
    Summarise ``score_map`` by its mean and variance over all elements.

    Returns
    -------
    (mean, var) : tuple
        ``np.mean`` and ``np.var`` of the flattened input; the variance is
        the population variance (NumPy's default ``ddof=0``).
    """
    return np.mean(score_map), np.var(score_map)
def Filter(score_map, type=0, sigma=4, size=3):
    '''
    Smooth a score map with a Gaussian or a median filter.

    Parameters
    ----------
    score_map : score map as tensor or ndarray
        Passed straight to ``scipy.ndimage``; the result is an ndarray.
    type : Int, optional
        Which filter to apply:
        0 = Gaussian (``gaussian_filter`` with ``sigma``)
        1 = Median (``median_filter`` with ``size``)
    sigma : float, optional
        Standard deviation of the Gaussian kernel (default 4).
    size : int, optional
        Window size of the median filter (default 3).

    Returns
    -------
    score: Filtered score

    Raises
    ------
    ValueError
        If ``type`` is neither 0 nor 1 (previously this surfaced as an
        UnboundLocalError on the return statement).
    '''
    if type == 0:
        return gaussian_filter(score_map, sigma=sigma)
    if type == 1:
        return median_filter(score_map, size=size)
    raise ValueError(
        f"Unknown filter type {type!r}; expected 0 (Gaussian) or 1 (Median)"
    )
def Binarization(mask, thres = 0., type = 0):
    """
    Threshold ``mask`` element-wise against ``thres``.

    Parameters
    ----------
    mask : array-like
        Values compared element-wise with ``thres``.
    thres : float, optional
        Only values strictly greater than this survive (default 0.).
    type : int, optional
        0 -> hard mask: 1.0 where ``mask > thres``, else 0.0.
        1 -> soft mask: the original value where ``mask > thres``, else 0.0.
        Any other value returns ``mask`` untouched.

    Returns
    -------
    The thresholded array (or ``mask`` itself for an unrecognised ``type``).
    """
    if type not in (0, 1):
        return mask
    keep = mask > thres
    fill = 1. if type == 0 else mask
    return np.where(keep, fill, 0.)
def plot(image,grnd_truth, score, title):
    """
    Save one figure per sample showing input, ground truth and prediction.

    Parameters
    ----------
    image : tensor
        Indexed as ``image[k]``; each item is ``.permute(1, 2, 0)``-ed for
        display, so presumably (N, C, H, W) torch tensors on CPU -- confirm.
    grnd_truth : tensor
        Same indexing/permutation as ``image``.
    score : tensor
        Detached, moved to CPU and converted to NumPy; each ``score[k]`` is
        transposed (1, 2, 0) for display.
    title : str
        Output file name. Its extension is dropped and ``_<k>.png`` is
        appended, so sample k is written to ``<base>_<k>.png``.
    """
    print('Images ' , image.shape)
    print('grnd_truth ' , grnd_truth.shape)
    print('score ' , score.shape)
    score = score.cpu().detach().numpy()
    # Compute the base name once. Re-splitting a name built in a previous
    # iteration would chain suffixes (x_0.png, x_0_1.png, ...), and
    # splitext (unlike split('.')) copes with dots in directory names
    # such as './out/x.png'.
    base = os.path.splitext(title)[0]
    for k in range(image.shape[0]):
        # Fresh figure per sample so images and colorbars don't accumulate
        fig = plt.figure()
        plt.subplot(131)
        plt.imshow(image[k].permute(1,2,0))
        plt.subplot(132)
        plt.imshow(grnd_truth[k].permute(1,2,0))
        plt.xlabel('ground truth')
        plt.subplot(133)
        plt.imshow(np.transpose(score[k], (1,2,0)))
        plt.xlabel('predicted')
        plt.colorbar()
        plt.savefig(base + '_' + str(k) + '.png')
        # Release the figure; otherwise every sample leaks an open figure
        plt.close(fig)
def binImage(heatmap, thres=0 ):
    """
    Binarise ``heatmap`` to {0, 255} using OpenCV's Otsu thresholding.

    Because ``cv2.THRESH_OTSU`` is combined with ``cv2.THRESH_BINARY``,
    OpenCV chooses the threshold automatically and ``thres`` is effectively
    ignored. The threshold OpenCV computed is discarded; only the binary
    image is returned.
    NOTE(review): OpenCV's Otsu mode expects an 8-bit single-channel image;
    confirm callers pass a uint8 heatmap.
    """
    mode = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    binary = cv2.threshold(heatmap, thres, 255, mode)[1]
    # Fixed threshold t from the paper, kept for reference:
    #   cv2.threshold(heatmap, 178, 255, cv2.THRESH_BINARY)
    return binary
def selectMaxConnect(heatmap):
    """
    Keep only the largest connected foreground component of ``heatmap``.

    Non-zero pixels are foreground (``background=0``). Components are found
    with ``skimage.measure.label`` using ``connectivity=2``.

    Returns
    -------
    lcc : integer ndarray, same shape as the label image
        1 on the largest component and 0 elsewhere. On ties, the
        lowest-numbered label wins. The result is all zeros when there is
        no foreground.
    """
    labeled_img, num = label(heatmap, connectivity=2, background=0, return_num=True)
    if num == 0:
        # No foreground: label 0 is the background, so return an empty mask
        return np.zeros_like(labeled_img, dtype=bool) + 0
    # Count every label in one pass instead of two full-image comparisons
    # per component.
    sizes = np.bincount(labeled_img.ravel(), minlength=num + 1)
    sizes[0] = 0  # never select the background
    # argmax returns the first maximum, i.e. the lowest label on ties
    max_label = int(np.argmax(sizes))
    lcc = (labeled_img == max_label) + 0
    return lcc