Skip to content

Commit

Permalink
Replace `mean` with `nanmean` (and `std` with `nanstd`) in meta_evaluation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
annahedstroem authored Apr 17, 2024
1 parent f8779f8 commit 60b2457
Showing 1 changed file with 43 additions and 51 deletions.
94 changes: 43 additions & 51 deletions metaquantus/meta_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from .perturbation_tests.base import PerturbationTestBase
from .helpers.sanity_checks import sanity_analysis, sanity_analysis_under_perturbation
import numpy as np


class MetaEvaluation:
Expand Down Expand Up @@ -130,7 +131,6 @@ def __call__(
channel_first: Optional[bool] = True,
softmax: Optional[bool] = False,
device: Optional[str] = None,
batch_size: Optional[int] = 64,
model_predict_kwargs: Optional[Dict[str, Any]] = {},
score_direction: Optional[str] = None,
):
Expand All @@ -155,8 +155,6 @@ def __call__(
Indicates if channels is first.
softmax: bool
Indicates if the softmax (or logits) are used.
batch_size: int
The batch size to run Quantus evaluation with.
device: torch.device
The device used, to enable GPUs.
model_predict_kwargs: dict
Expand All @@ -182,7 +180,6 @@ def __call__(
channel_first=channel_first,
softmax=softmax,
device=device,
batch_size=batch_size,
)

# Run inference.
Expand Down Expand Up @@ -259,7 +256,6 @@ def run_perturbation_analysis(
s_batch: Union[np.array, None] = None,
channel_first: Optional[bool] = True,
softmax: Optional[bool] = False,
batch_size: Optional[int] = 64,
device: Optional[str] = None,
model_predict_kwargs: Optional[Dict[str, Any]] = {},
):
Expand Down Expand Up @@ -357,7 +353,6 @@ def run_perturbation_analysis(
},
model_predict_kwargs=model_predict_kwargs,
softmax=softmax,
batch_size=batch_size,
device=device,
)

Expand Down Expand Up @@ -403,7 +398,6 @@ def run_perturbation_analysis(
model_predict_kwargs=model_predict_kwargs,
softmax=softmax,
device=device,
batch_size=batch_size,
)

self.results_eval_scores_perturbed[test_name][i] = scores_perturbed
Expand Down Expand Up @@ -585,50 +579,48 @@ def run_meta_consistency_analysis(self) -> dict:
shape
)

# Aggregate the raw consistency scores over the appropriate axes.
# NOTE(review): np.nanmean/np.nanstd are used instead of .mean()/.std() so
# that NaNs from individual failed evaluation runs do not poison every
# downstream aggregate (a single NaN would otherwise make the final
# MC score NaN).
# Assumes self.results_consistency_scores[perturbation_type][...] are
# numpy arrays shaped so that axis reduction leaves `self.iterations`
# entries per key — TODO confirm against the reshape below.
consistency_scores = {
    "IAC_{NR}": np.nanmean(
        self.results_consistency_scores[perturbation_type]["intra_scores_res"],
        axis=(0, 2),
    ),
    "IAC_{AR}": np.nanmean(
        self.results_consistency_scores[perturbation_type]["intra_scores_adv"],
        axis=(0, 2),
    ),
    "IEC_{NR}": np.nanmean(
        self.results_consistency_scores[perturbation_type]["inter_scores_res"],
        axis=1,
    ),
    "IEC_{AR}": np.nanmean(
        self.results_consistency_scores[perturbation_type]["inter_scores_adv"],
        axis=1,
    ),
}

# Per-score-type summary statistics (mean and spread over iterations).
# BUGFIX: "IAC_{AR} std" previously read consistency_scores["IAC_{NR}"],
# reporting the wrong spread for the adversarial intra-consistency score.
consistency_results = {
    "IAC_{NR} mean": np.nanmean(consistency_scores["IAC_{NR}"]),
    "IAC_{NR} std": np.nanstd(consistency_scores["IAC_{NR}"]),
    "IAC_{AR} mean": np.nanmean(consistency_scores["IAC_{AR}"]),
    "IAC_{AR} std": np.nanstd(consistency_scores["IAC_{AR}"]),
    "IEC_{NR} mean": np.nanmean(consistency_scores["IEC_{NR}"]),
    "IEC_{NR} std": np.nanstd(consistency_scores["IEC_{NR}"]),
    "IEC_{AR} mean": np.nanmean(consistency_scores["IEC_{AR}"]),
    "IEC_{AR} std": np.nanstd(consistency_scores["IEC_{AR}"]),
}

# Produce the meta-consistency (MC) results: stack the four score vectors
# into a (4, iterations) matrix, then average per iteration (MC_means),
# overall (MC_mean), and take the spread of the per-iteration means (MC_std).
shape = (4, self.iterations)
stacked_scores = np.array(list(consistency_scores.values())).reshape(shape)
mc_means = np.nanmean(stacked_scores, axis=0)
self.results_meta_consistency_scores[perturbation_type] = {
    "consistency_scores": consistency_scores,
    "consistency_results": consistency_results,
    "MC_means": mc_means,
    "MC_mean": np.nanmean(stacked_scores),
    "MC_std": np.nanstd(mc_means),
}
# ...

# Aggregate the raw consistency scores over the appropriate axes,
# ignoring NaNs from failed runs.
# NOTE(review): assumes the per-key arrays reduce to `self.iterations`
# values each — confirm against the (4, iterations) reshape below.
score_arrays = self.results_consistency_scores[perturbation_type]
consistency_scores = {
    "IAC_{NR}": np.nanmean(score_arrays["intra_scores_res"], axis=(0, 2)),
    "IAC_{AR}": np.nanmean(score_arrays["intra_scores_adv"], axis=(0, 2)),
    "IEC_{NR}": np.nanmean(score_arrays["inter_scores_res"], axis=1),
    "IEC_{AR}": np.nanmean(score_arrays["inter_scores_adv"], axis=1),
}

# Per-score-type summary statistics: "<key> mean" and "<key> std" for
# each of the four consistency scores, in the same key order as above.
consistency_results = {}
for score_name, score_values in consistency_scores.items():
    consistency_results[f"{score_name} mean"] = np.nanmean(score_values)
    consistency_results[f"{score_name} std"] = np.nanstd(score_values)

# Meta-consistency (MC) results: stack the four score vectors into a
# (4, iterations) matrix, average per iteration (MC_means), overall
# (MC_mean), and take the spread of the per-iteration means (MC_std).
shape = (4, self.iterations)
stacked_scores = np.array(list(consistency_scores.values())).reshape(shape)
mc_means = np.nanmean(stacked_scores, axis=0)
self.results_meta_consistency_scores[perturbation_type] = {
    "consistency_scores": consistency_scores,
    "consistency_results": consistency_results,
    "MC_means": mc_means,
    "MC_mean": np.nanmean(stacked_scores),
    "MC_std": np.nanstd(mc_means),
}
if self.print_results:
print(
f"\n{perturbation_type} Perturbation Test ---> MC score="
Expand Down

0 comments on commit 60b2457

Please sign in to comment.