Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use torchmetric PSNR implementation and argument ordering #693

Merged
merged 22 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
5409c32
Used proper torchmetric parameter order and replaced own PSNR impleme…
FelixSteinbauer Jul 18, 2023
dcb35a3
Merge branch 'master' into patch-2
sarthakpati Jul 19, 2023
9b4add6
Merge branch 'master' into patch-2
sarthakpati Jul 21, 2023
00ea9c5
Added peak_signal_noise_ratio_eps as dicussed
FelixSteinbauer Jul 21, 2023
960e47a
Added peak_signal_noise_ratio_eps call for synthesis case.
FelixSteinbauer Jul 21, 2023
e3e36d7
Added peak_signal_noise_ratio_eps also the the metrics init file
FelixSteinbauer Jul 21, 2023
0e2d72e
Removed Trailing whitespace (I think)
FelixSteinbauer Jul 21, 2023
5dee4ec
Merge branch 'master' into patch-2
sarthakpati Jul 26, 2023
08aac5f
Added epsilon and data_range as parameters to PSNR
FelixSteinbauer Jul 29, 2023
3ffbaa7
Unified PSNR versions
FelixSteinbauer Jul 29, 2023
af8c12d
Update generate_metrics.py for usage of the unified PSNR signature
FelixSteinbauer Jul 29, 2023
b42ca4e
Fixed IndentationError?
FelixSteinbauer Jul 29, 2023
fdea518
Fixed: unmatched ")"
FelixSteinbauer Jul 29, 2023
d2bf127
Removed unused sys package
FelixSteinbauer Jul 29, 2023
acef5d8
Added required sys package
FelixSteinbauer Jul 29, 2023
9472f06
Different description (comment) for non-torchmetrics PSNR
FelixSteinbauer Jul 31, 2023
5be2fca
Fixed wrong parenthesis in PSR definition
FelixSteinbauer Jul 31, 2023
70f8ff5
Trying different quotation marks
FelixSteinbauer Jul 31, 2023
dfa1665
Tried removing comment
FelixSteinbauer Aug 1, 2023
5f584af
Trying to revert parenthesis fix
FelixSteinbauer Aug 1, 2023
d8b1f8d
Returned to current state of code
FelixSteinbauer Aug 1, 2023
974fde4
Merge branch 'master' into patch-2
sarthakpati Aug 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion GANDLF/cli/generate_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import yaml
from pprint import pprint
import pandas as pd
Expand Down Expand Up @@ -258,11 +259,14 @@ def __fix_2d_tensor(input_tensor):
gt_image_infill, output_infill
).item()

# PSNR - similar to pytorch PeakSignalNoiseRatio until 4 digits after decimal point
overall_stats_dict[current_subject_id]["psnr"] = peak_signal_noise_ratio(
gt_image_infill, output_infill
).item()

overall_stats_dict[current_subject_id]["psnr_eps"] = peak_signal_noise_ratio(
gt_image_infill, output_infill, epsilon=sys.float_info.epsilon
).item()

pprint(overall_stats_dict)
if outputfile is not None:
with open(outputfile, "w") as outfile:
Expand Down
31 changes: 19 additions & 12 deletions GANDLF/metrics/synthesis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
import SimpleITK as sitk
import PIL.Image
import numpy as np
Expand All @@ -8,6 +7,7 @@
MeanSquaredError,
MeanSquaredLogError,
MeanAbsoluteError,
PeakSignalNoiseRatio,
)
from GANDLF.utils import get_image_from_tensor

Expand All @@ -25,7 +25,7 @@ def structural_similarity_index(target, prediction, mask=None) -> torch.Tensor:
torch.Tensor: The structural similarity index.
"""
ssim = StructuralSimilarityIndexMeasure(return_full_image=True)
_, ssim_idx_full_image = ssim(target, prediction)
_, ssim_idx_full_image = ssim(preds=prediction, target=target)
mask = torch.ones_like(ssim_idx_full_image) if mask is None else mask
try:
ssim_idx = ssim_idx_full_image[mask]
Expand All @@ -45,23 +45,30 @@ def mean_squared_error(target, prediction) -> torch.Tensor:
prediction (torch.Tensor): The prediction tensor.
"""
mse = MeanSquaredError()
return mse(target, prediction)
return mse(preds=prediction, target=target)


def peak_signal_noise_ratio(target, prediction) -> torch.Tensor:
def peak_signal_noise_ratio(target, prediction, data_range=None, epsilon=None) -> torch.Tensor:
"""
Computes the peak signal to noise ratio between the target and prediction.

Args:
target (torch.Tensor): The target tensor.
prediction (torch.Tensor): The prediction tensor.
data_range (float, optional): If not None, this data range is used as enumerator instead of computing it from the given data. Defaults to None.
epsilon (float, optional): If not None, this epsilon is added to the denominator of the fraction to avoid infinity as output. Defaults to None.
"""
mse = mean_squared_error(target, prediction)
return (
10.0
* torch.log10((torch.max(target) - torch.min(target)) ** 2)
/ (mse + sys.float_info.epsilon)
)

if epsilon == None:
psnr = PeakSignalNoiseRatio(data_range=data_range)
return psnr(preds=prediction, target=target)
else: # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0'
mse = mean_squared_error(target, prediction)
if data_range == None: #compute data_range like torchmetrics if not given
min_v = 0 if torch.min(target) > 0 else torch.min(target) #look at this line
max_v = torch.max(target)
data_range = max_v - min_v
return 10.0 * torch.log10((data_range ** 2) / (mse + epsilon))
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved


def mean_squared_log_error(target, prediction) -> torch.Tensor:
Expand All @@ -73,7 +80,7 @@ def mean_squared_log_error(target, prediction) -> torch.Tensor:
prediction (torch.Tensor): The prediction tensor.
"""
mle = MeanSquaredLogError()
return mle(target, prediction)
return mle(preds=prediction, target=target)


def mean_absolute_error(target, prediction) -> torch.Tensor:
Expand All @@ -85,7 +92,7 @@ def mean_absolute_error(target, prediction) -> torch.Tensor:
prediction (torch.Tensor): The prediction tensor.
"""
mae = MeanAbsoluteError()
return mae(target, prediction)
return mae(preds=prediction, target=target)


def _get_ncc_image(target, prediction) -> sitk.Image:
Expand Down
Loading