-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils_uncertainty.py
400 lines (318 loc) · 12.7 KB
/
utils_uncertainty.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
#!/usr/bin/env python
"""
Description: This file contains the implementation of the following uncertainty metrics:
- Expected Calibration Error (ECE)
- Entropy
- Area under the Receiver Operating Characteristic (ROC) curve (AUROC)
- Kullback-Leibler Divergence (KLD)
Source:
https://github.com/prs-eth/FILM-Ensemble/blob/main/utils_uncertainty.py
Turkoglu, M. O., Becker, A., Gündüz, H. A., Rezaei, M., Bischl, B., Daudt, R. C.,
D'Aronco, S., Wegner, J. D., & Schindler, K. (2022). FiLM-Ensemble: Probabilistic
Deep Learning via Feature-wise Linear Modulation. In Advances in Neural Information Processing Systems.
"""
### IMPORTS ##
# Built-in imports
from typing import Union
import pathlib
# Lib imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
# Custom imports
import const
### CLASS DEFINITIONS ###
class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).

    The input to this loss is the logits of a model, NOT the softmax scores.

    This divides the confidence outputs into equally-sized interval bins.
    In each bin, we compute the confidence gap:

    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |

    We then return a weighted average of the gaps, based on the number
    of samples in each bin.

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """

    def __init__(self, n_bins=10):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        # n_bins equal-width bins over [0, 1]; each bin is the half-open
        # interval (lower, upper] (see the gt/le test in forward).
        self.bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = self.bin_boundaries[:-1]
        self.bin_uppers = self.bin_boundaries[1:]

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        plot: bool = False,
        file_name: Union[str, pathlib.Path] = None,
        threshold: float = None
    ):
        """
        Calculate the Expected Calibration Error (ECE) and plot the reliability diagram if needed

        Parameters
        ----------
        logits: torch.Tensor
            The logits of the model; the softmax is taken over dim=1
        labels: torch.Tensor
            The labels of the data, compared element-wise with the argmax prediction
        plot: bool, optional
            Whether to plot the reliability diagram
            Default: False
        file_name: str or pathlib.Path, optional
            The name (without extension) of the file to save the plot to,
            relative to const.PLOT_DIR. Must be provided if plot is True.
            Default: None
        threshold: float, optional
            Unused; accepted only for interface compatibility.

        Returns
        -------
        ece: torch.Tensor
            The Expected Calibration Error (shape (1,))
        accs: List
            List of accuracies (tensors) for the non-empty bins
        confs: List
            List of average confidences (tensors) for the non-empty bins
        avg_acc: float
            The samplewise average accuracy
        avg_conf: float
            The samplewise average confidence

        Raises
        ------
        ValueError
            If plot is True but file_name is None.
        """
        if plot and file_name is None:
            # Fail fast instead of silently writing "None.pdf" / "None.png".
            raise ValueError("file_name must be provided when plot is True")

        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)
        avg_conf = confidences.float().mean().item()
        avg_acc = accuracies.float().mean().item()

        ece = torch.zeros(1, device=logits.device)
        accs = list()       # accuracy tensors of non-empty bins (returned)
        accs_all = list()   # accuracy floats of every bin, 0 if empty (plotting)
        confs = list()      # confidence tensors of non-empty bins (returned)
        confs_all = list()  # confidence floats of every bin, 0 if empty (plotting)
        counts = list()     # number of samples per bin
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            in_bin = confidences.gt(bin_lower.item()) & confidences.le(bin_upper.item())
            counts.append(in_bin.sum().item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                # |confidence - accuracy| weighted by the fraction of samples in the bin
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
                accs.append(accuracy_in_bin)
                accs_all.append(accuracy_in_bin.item())
                confs.append(avg_confidence_in_bin)
                confs_all.append(avg_confidence_in_bin.item())
            else:
                accs_all.append(0)
                confs_all.append(0)
        counts = np.array(counts)

        if plot:
            self.plot_reliability_diagram(accs_all, confs_all, counts, avg_conf, avg_acc, file_name)
        return ece, accs, confs, avg_acc, avg_conf

    def plot_reliability_diagram(
        self,
        accs: list,
        confs: list,
        counts: np.ndarray,
        avg_conf: float,
        avg_acc: float,
        file_name: str
    ) -> None:
        """
        Plots the reliability diagram and saves it as PDF and PNG.

        Parameters
        ----------
        accs: List
            List of accuracies per bin (0 for empty bins)
        confs: List
            List of confidences per bin (currently not drawn)
        counts: np.ndarray
            Number of elements per bin
        avg_conf: float
            Average confidence across all samples
        avg_acc: float
            Average accuracy across all samples
        file_name: str or pathlib.Path
            The name (without extension) of the file to save the plot to,
            relative to const.PLOT_DIR

        Added by M. Halbheer and D. Mühlematter.
        """
        # Convert to numpy arrays
        accs, confs = np.array(accs), np.array(confs)

        # Set the font sizes
        title_fontsize = 24
        label_fontsize = 20
        tick_fontsize = 18
        legend_fontsize = 18

        # Create the figure
        fig, ax0 = plt.subplots(1, 1, sharex="col", figsize=(10, 8))

        # Bars are centred in their bins and span the full bin width
        bin_size = 1 / len(counts)
        positions = self.bin_boundaries.numpy()[:-1] + bin_size / 2
        width = bin_size

        # Plot the accuracy bars
        ax0.bar(positions, width=width, edgecolor='black', height=accs,
                color='lightgray', label='Accuracy', linewidth=2)

        # For non-empty bins, shade the gap between the bin accuracy and the
        # diagonal evaluated at the bin centre
        non_empty = counts > 0
        gap_positions = positions[non_empty]
        gap_accs = accs[non_empty]
        ax0.bar(gap_positions, height=np.abs(gap_accs - gap_positions),
                bottom=np.minimum(gap_accs, gap_positions), width=width,
                edgecolor='darkred', color='red', alpha=0.4, linewidth=1, label="Gap")

        # Plot the diagonal (perfectly calibrated model)
        ax0.set_aspect("equal", adjustable='box')
        ax0.plot([0, 1], [0, 1], linestyle="--", color="gray")

        # Set the axis limits
        ax0.set_xlim(0, 1)
        ax0.set_ylim(0, 1)

        # Format the axes
        ax0.set_ylabel("Accuracy", fontsize=label_fontsize)
        ax0.set_yticks(np.linspace(0, 1, 11))
        ax0.tick_params(axis='y', which='major', labelsize=tick_fontsize)

        # Enable grid
        ax0.grid()

        # Set the title
        ax0.set_title("Reliability Diagram", fontsize=title_fontsize)

        # Fraction of samples in each bin; guard against an empty input
        # (all counts zero), which would otherwise divide by zero.
        total = np.sum(counts)
        if total > 0:
            percentage_counts = counts / total
        else:
            percentage_counts = np.zeros(len(counts))

        # Conversion functions for the secondary y-axis
        def percentage_to_norm(x):
            return x / 100

        def norm_to_percentage(x):
            return x * 100

        # Secondary y-axis for the percentage of samples in each bin
        secax = ax0.secondary_yaxis('right', functions=(norm_to_percentage, percentage_to_norm))
        secax.set_yticks(np.linspace(0, 100, 11))
        secax.set_ylabel('% of samples', fontsize=label_fontsize)
        secax.tick_params(axis='y', which='major', labelsize=tick_fontsize)

        # Histogram of the confidences, drawn as zero-height bars (horizontal
        # markers) placed at the fraction of samples in each bin
        ax0.bar(positions, 0, bottom=percentage_counts, width=0.9*width,
                edgecolor="darkslategray", color="darkslategray", alpha=1.0, linewidth=3,
                label="% of samples in bin")

        # Format the axes
        ax0.tick_params(axis='both', which='major', labelsize=tick_fontsize)

        # Add vertical lines for the average accuracy and confidence
        ax0.axvline(x=avg_conf, color='darkgoldenrod', linestyle='--', label=f'Avg. Confidence {avg_conf * 100:.1f}%')
        ax0.axvline(x=avg_acc, color='darkgoldenrod', ls='solid', label=f'Avg. Accuracy {avg_acc * 100:.1f}%')

        # Add a legend
        ax0.legend(fontsize=legend_fontsize)

        # Format the entire plot layout
        plt.tight_layout()
        plt.subplots_adjust(hspace=-0.02)

        # Create the plot directory if needed (exist_ok avoids a check/create race)
        const.PLOT_DIR.mkdir(parents=True, exist_ok=True)

        # Set absolute plot size
        fig.set_size_inches(8, 8)

        # Save the plot
        png_name = f"{file_name}.png"
        pdf_name = f"{file_name}.pdf"
        fig.savefig(const.PLOT_DIR.joinpath(pdf_name), bbox_inches='tight')
        fig.savefig(const.PLOT_DIR.joinpath(png_name), bbox_inches='tight')
        # Close this specific figure so repeated calls do not leak figures
        plt.close(fig)
class Entropy(nn.Module):
    """
    Calculates the entropy of the predictive categorical distribution and
    averages it over the batch dimension.

    Parameters
    ----------
    softmax: bool, optional
        If True (default), the input is treated as logits and normalised with
        a softmax over dim=1. If False, the input is treated as class
        probabilities.
    """

    def __init__(self, softmax=True):
        super(Entropy, self).__init__()
        self.softmax = softmax

    def forward(self, logits):
        """
        Parameters
        ----------
        logits: torch.Tensor
            Logits (softmax=True) or class probabilities (softmax=False).
            Presumably shaped (batch, classes): the softmax uses dim=1 while
            Categorical treats the last axis as the class axis.

        Returns
        -------
        entropy: torch.Tensor
            Mean entropy (natural log) over the batch, as a 0-dim tensor.
        """
        if self.softmax:
            # Hand normalised log-probabilities to Categorical as logits.
            # Passing softmax *probabilities* as `logits` would make
            # Categorical apply a second softmax and give the wrong entropy.
            dist = Categorical(logits=F.log_softmax(logits, dim=1))
        else:
            dist = Categorical(probs=logits)
        return dist.entropy().mean()
class AUROC(nn.Module):
    """
    Out-of-distribution (OOD) detection quality measured as the area under
    the Receiver Operating Characteristic (ROC) curve (AUROC).

    In-distribution samples are labelled 1 and OOD samples 0; each sample's
    detection score is its maximum class score.

    Parameters
    ----------
    softmax: bool
        Apply a softmax over dim=1 to both inputs before scoring.
    equal_size: bool
        Truncate both inputs to the length of the shorter one so that both
        sets contribute the same number of samples.
    """

    def __init__(self, softmax=True, equal_size=True):
        super(AUROC, self).__init__()
        self.softmax = softmax
        self.equal_size = equal_size

    def forward(self, id_logits, ood_logits):
        """Return sklearn's roc_auc_score for separating ID from OOD samples."""
        if self.softmax:
            id_logits, ood_logits = (F.softmax(t, dim=1) for t in (id_logits, ood_logits))

        if self.equal_size:
            n_keep = min(id_logits.shape[0], ood_logits.shape[0])
            id_logits, ood_logits = id_logits[:n_keep], ood_logits[:n_keep]

        # Maximum class score per sample is the confidence used for detection
        id_scores = id_logits.max(dim=1).values
        ood_scores = ood_logits.max(dim=1).values

        scores = torch.cat([id_scores, ood_scores], dim=0)
        targets = torch.cat([torch.ones_like(id_scores), torch.zeros_like(ood_scores)], dim=0)

        return roc_auc_score(targets.detach().cpu().numpy(), scores.detach().cpu().numpy())
class function_space_analysis(nn.Module):
    """
    Calculates the disagreement and distance between two models.

    The distance is the symmetric Kullback-Leibler divergence
    0.5 * (KL(p2 || p1) + KL(p1 || p2)), summed over the batch, optionally
    followed by a square root. Note: this is not the Jensen-Shannon
    divergence, which compares each distribution to their mixture.

    Parameters
    ----------
    w_softmax: bool
        Apply log_softmax over dim=1 to both inputs. If False, the inputs
        must already be log-probabilities (the KL loss is built with
        log_target=True and expects log-space input).
    square_root: bool
        Take the square root of the divergence.

    Returns
    -------
    disagreement: torch.Tensor
        Number of samples whose argmax predictions differ (a count, not
        normalised by the batch size).
    distance: torch.Tensor
        Distance between the two models.
    """

    def __init__(self, w_softmax=True, square_root=True):
        super(function_space_analysis, self).__init__()
        self.w_softmax = w_softmax
        self.square_root = square_root
        # reduction='sum' sums over all samples and classes in the batch
        self.lossFn = nn.KLDivLoss(reduction='sum', log_target=True)

    def forward(self, logits_1, logits_2):
        # Convert each input independently so that mixed tensor / array-like
        # inputs work (as_tensor is a no-op for existing tensors).
        logits_1 = torch.as_tensor(logits_1)
        logits_2 = torch.as_tensor(logits_2)

        # Move to log-probability space if necessary
        if self.w_softmax:
            logits_1 = F.log_softmax(logits_1, dim=1)
            logits_2 = F.log_softmax(logits_2, dim=1)

        # lossFn(a, b) = KL(exp(b) || exp(a)); averaging both directions
        # gives the symmetric KL divergence
        distance = 0.5 * (self.lossFn(logits_1, logits_2) + self.lossFn(logits_2, logits_1))

        if self.square_root:
            distance = torch.sqrt(distance)

        # Argmax predictions of both models
        pred_1 = logits_1.max(1, keepdim=True)[1]
        pred_2 = logits_2.max(1, keepdim=True)[1]

        # Count of disagreeing predictions (deliberately not divided by batch size)
        disagreement = torch.sum(pred_1 != pred_2)
        return disagreement, distance
def precision(predictions, labels):
    """
    Sum of the per-class precisions of the model.

    Parameters
    ----------
    predictions: torch.Tensor
        The predictions of the model. Presumably one-hot encoded with the
        sample axis first, e.g. (N, C) or (N, M, C) — TODO confirm against callers.
    labels: torch.Tensor
        The labels of the data, in the same layout as ``predictions``.

    Returns
    -------
    prec: torch.Tensor
        TP / (number of predictions) per class, summed over axis 1 of that
        ratio (or over its only axis for 2-D inputs). A 0-dim tensor for
        2-D inputs.
        NOTE(review): this is a sum, not a mean, so with C classes it lies in
        [0, C] rather than [0, 1]; a class that is never predicted gives
        0/0 = nan. Confirm this is what callers expect.
    """
    predicted_per_class = predictions.sum(dim=0)
    true_positives = (predictions * labels).sum(dim=0)
    ratio = true_positives / predicted_per_class
    # dim=1 only exists when the ratio has at least two axes; for 2-D (N, C)
    # inputs the ratio is a 1-D per-class vector, so reduce its last axis
    # instead of raising IndexError.
    reduce_dim = 1 if ratio.dim() > 1 else -1
    return ratio.sum(dim=reduce_dim)