-
Notifications
You must be signed in to change notification settings - Fork 0
/
histogram_match
47 lines (37 loc) · 1.83 KB
/
histogram_match
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import SimpleITK as sitk
import sys
def histogram_match(input_img, ref_img, mask_img):
    """Match the intensity histogram of ``input_img`` to that of ``ref_img``.

    The intensity mapping is estimated from the two images after each has
    been passed through ``sitk.Mask`` with ``mask_img``.  The mapping is then
    applied to every voxel of ``input_img`` (not only those inside the mask)
    by piecewise-linear interpolation.

    Args:
        input_img: SimpleITK image whose intensities are remapped.
        ref_img: SimpleITK image that provides the target histogram.
        mask_img: SimpleITK mask image given to ``sitk.Mask`` for both images.

    Returns:
        A new SimpleITK image holding the remapped intensities, with the
        image information copied from ``input_img``.  The pixel values are
        produced by ``np.interp`` and are therefore floating point.
    """
    # NOTE(review): sitk.Mask with default arguments is expected to set voxels
    # outside the mask to 0; those zeros then take part in both histograms
    # below -- confirm this is the intended behaviour.
    input_mask_arr = sitk.GetArrayFromImage(sitk.Mask(input_img, mask_img))
    ref_mask_arr = sitk.GetArrayFromImage(sitk.Mask(ref_img, mask_img))
    input_arr = sitk.GetArrayFromImage(input_img)
    img_shape = input_arr.shape

    # Distinct intensities of the full input image; ``inverse_i`` rebuilds the
    # flattened image from ``unique_i`` once the values have been remapped.
    unique_i, inverse_i = np.unique(input_arr.reshape(-1), return_inverse=True)

    # Distinct intensities and their frequencies in the masked input (im)
    # and the masked reference (rm).
    unique_im, counts_im = np.unique(
        input_mask_arr.reshape(-1), return_counts=True
    )
    unique_rm, counts_rm = np.unique(
        ref_mask_arr.reshape(-1), return_counts=True
    )

    # Empirical CDF (quantile) of every distinct intensity.  float64 is used
    # on purpose: float32 cannot represent counts above 2**24 exactly, which
    # makes neighbouring quantiles collapse for large volumes.
    im_cum = np.cumsum(counts_im).astype(np.float64)
    im_qtl = im_cum / im_cum[-1]
    rm_cum = np.cumsum(counts_rm).astype(np.float64)
    rm_qtl = rm_cum / rm_cum[-1]

    # For each masked-input intensity, look up the reference intensity that
    # sits at the same quantile.
    interp_unique_im = np.interp(im_qtl, rm_qtl, unique_rm)

    # Extend the mapping (unique_im -> interp_unique_im) to every intensity
    # that occurs anywhere in the full input image.
    interp_unique_i = np.interp(unique_i, unique_im, interp_unique_im)

    match_arr = interp_unique_i[inverse_i].reshape(img_shape)
    match_img = sitk.GetImageFromArray(match_arr)
    match_img.CopyInformation(input_img)
    return match_img