Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduced percentile normalization for synthesis challenge metrics #700

Merged
merged 6 commits into from
Aug 2, 2023
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions GANDLF/cli/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,30 @@ def __fix_2d_tensor(input_tensor):
else:
return input_tensor

def __percentile_clip(input_tensor, reference_tensor=None, p_min=0.5, p_max=99.5, strictlyPositive=True):
"""Normalizes a tensor based on percentiles. Clips values below and above the percentile.
Percentiles for normalization can come from another tensor.

Args:
input_tensor (torch.Tensor): Tensor to be normalized based on the data from the reference_tensor.
If reference_tensor is None, the percentiles from this tensor will be used.
reference_tensor (torch.Tensor, optional): The tensor used for obtaining the percentiles.
p_min (float, optional): Lower end percentile. Defaults to 0.5.
p_max (float, optional): Upper end percentile. Defaults to 99.5.
strictlyPositive (bool, optional): Ensures that really all values are above 0 before normalization. Defaults to True.

Returns:
torch.Tensor: The input_tensor normalized based on the percentiles of the reference tensor.
"""
reference_tensor = input_tensor if reference_tensor is None else reference_tensor
v_min, v_max = np.percentile(reference_tensor, [p_min,p_max]) #get p_min percentile and p_max percentile

# set lower bound to be 0 if strictlyPositive is enabled
v_min = max(v_min, 0.0) if strictlyPositive else v_min
output_tensor = np.clip(input_tensor,v_min,v_max) #clip values to percentiles from reference_tensor
output_tensor = (output_tensor - v_min)/(v_max-v_min) #normalizes values to [0;1]
return output_tensor

for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
current_subject_id = row[headers["subjectid"]]
overall_stats_dict[current_subject_id] = {}
Expand All @@ -219,9 +243,9 @@ def __fix_2d_tensor(input_tensor):
# Normalize to [0;1] based on GT (otherwise MSE will depend on the image intensity range)
normalize = parameters.get("normalize", True)
if normalize:
v_max = gt_image_infill.max()
output_infill /= v_max
gt_image_infill /= v_max
reference_tensor = target_image * ~mask #use all the tissue that is not masked for normalization
gt_image_infill = __percentile_clip(gt_image_infill, reference_tensor=reference_tensor, p_min=0.5, p_max=99.5, strictlyPositive=True)
output_infill = __percentile_clip(output_infill, reference_tensor=reference_tensor, p_min=0.5, p_max=99.5, strictlyPositive=True)

overall_stats_dict[current_subject_id][
"ssim"
Expand Down Expand Up @@ -258,6 +282,7 @@ def __fix_2d_tensor(input_tensor):
gt_image_infill, output_infill
).item()

#TODO: use data_range=1.0 as parameter for PSNR when the Pull request is accepted that introduces the data_range parameter!
# PSNR - similar to pytorch PeakSignalNoiseRatio until 4 digits after decimal point
overall_stats_dict[current_subject_id]["psnr"] = peak_signal_noise_ratio(
gt_image_infill, output_infill
Expand Down
Loading