-
Notifications
You must be signed in to change notification settings - Fork 0
/
ipw.py
executable file
·162 lines (118 loc) · 5.58 KB
/
ipw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from typing import Dict, Type, Optional, Any, List, Tuple
import torch
import torch.nn as nn
import pytorch_lightning as pl
from arl import Learner
class IPW(pl.LightningModule):
    """Feed forward neural network with modified BCE loss, based on inverse
    probability weighting of the losses.

    Each sample's BCE loss is divided by the empirical probability of its
    protected group (or of its (group, label) pair), so under-represented
    groups contribute proportionally more to the training objective.

    Attributes:
        config: Dict with hyperparameters (learning rate, batch size).
        num_features: Dimensionality of the data input.
        group_probs: Empirical observation probabilities of the different
            protected groups. Expected 1-D (indexed by group) or, when
            ``sensitive_label`` is True, 2-D indexed by (group, label) —
            TODO confirm against how callers build this tensor.
        hidden_units: Number of hidden units in each layer of the network.
        optimizer: Optimizer used to update the model parameters.
        sensitive_label: Option to use joint probability of label and group
            membership for computing the weights.
        opt_kwargs: Optional; optimizer keywords other than learning rate.
    """

    def __init__(self,
                 config: Dict[str, Any],
                 num_features: int,
                 group_probs: torch.Tensor,
                 hidden_units: Optional[List[int]] = None,
                 optimizer: Type[torch.optim.Optimizer] = torch.optim.Adagrad,
                 sensitive_label: bool = False,
                 opt_kwargs: Optional[Dict[str, Any]] = None,
                 ):
        """Inits an instance of IPW with the given attributes.

        Args:
            config: Dict with hyperparameters; must contain key ``'lr'``.
            num_features: Dimensionality of the data input.
            group_probs: Empirical group (or group x label) probabilities.
            hidden_units: Hidden layer sizes; defaults to [64, 32].
            optimizer: Optimizer class used in configure_optimizers.
            sensitive_label: If True, weight by joint (group, label) probability.
            opt_kwargs: Extra optimizer keyword arguments (besides lr).
        """
        super().__init__()
        # Resolve None sentinels here instead of using mutable defaults in the
        # signature (a shared [64,32] / {} default would be aliased across
        # instances). Must happen BEFORE save_hyperparameters(), which captures
        # the current local values of the named arguments.
        hidden_units = [64, 32] if hidden_units is None else hidden_units
        opt_kwargs = {} if opt_kwargs is None else opt_kwargs
        # save params EXCEPT group_probs since that throws an error
        self.save_hyperparameters('config', 'num_features', 'hidden_units',
                                  'optimizer', 'sensitive_label', 'opt_kwargs')
        # store the tensor on hparams manually (save_hyperparameters rejects it)
        self.hparams.group_probs = group_probs
        # save group probabilities
        self.group_probs = group_probs
        # init networks
        self.learner = Learner(input_shape=num_features, hidden_units=hidden_units)
        # per-sample losses (reduction='none') so learner_step can reweight them
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Returns and logs the (reweighted) loss on the training set.

        Args:
            batch: Inputs, labels and group memberships of a data batch.
            batch_idx: Index of batch in the dataset (not needed).

        Returns:
            Scalar training loss for backpropagation.
        """
        x, y, s = batch
        # group memberships s are passed -> inverse probability weighting is applied
        loss = self.learner_step(x, y, s)
        # logging
        self.log("training/loss", loss)
        return loss

    def learner_step(self, x: torch.Tensor, y: torch.Tensor, s: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Computes the inversely reweighted or unweighted BCE loss.

        Args:
            x: Tensor of shape [batch_size, num_features] with data inputs.
            y: Tensor of shape [batch_size] with labels.
            s: Optional; tensor of shape [batch_size] with group indices.

        Returns:
            One of the following:
            The mean of single BCE losses that are reweighted with the inverse
            of the joint probabilities of labels and group memberships.
            The mean of single BCE losses that are reweighted with the inverse
            of the group probabilities.
            The unweighted BCE loss.
        """
        # compute unweighted per-sample bce
        logits = self.learner(x)
        bce = self.loss_fct(logits, y)
        # consider both s and y for selecting probability
        if s is not None:
            # compute weights; .to(self.device) because group_probs is a plain
            # attribute, not a registered buffer, so Lightning does not move it
            if self.hparams.sensitive_label:
                # select rows by group, then columns by label -> joint probability
                sample_weights = torch.index_select(torch.index_select(self.group_probs.to(self.device), 0, s), 1, y.long())
            else:
                sample_weights = torch.index_select(self.group_probs.to(self.device), 0, s)
            # dividing by the observation probability = inverse probability weighting
            loss = torch.mean(bce / sample_weights)
        else:
            # compute unweighted loss
            loss = torch.mean(bce)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
        """Computes and logs the (unweighted) validation loss.

        Args:
            batch: Inputs, labels and group memberships of a data batch.
            batch_idx: Index of batch in the dataset (not needed).
        """
        x, y, _ = batch
        # group memberships are ignored -> plain BCE for comparable eval metrics
        loss = self.learner_step(x, y)
        # logging
        self.log("validation/loss", loss)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
        """Computes and logs the (unweighted) test loss.

        Args:
            batch: Inputs, labels and group memberships of a data batch.
            batch_idx: Index of batch in the dataset (not needed).
        """
        x, y, _ = batch
        loss = self.learner_step(x, y)
        # logging
        self.log("test/loss", loss)

    def configure_optimizers(self):
        """Chooses optimizer and learning-rate to use during optimization.

        Returns:
            Optimizer instance over the learner's parameters, with the learning
            rate taken from config['lr'] and any extra opt_kwargs applied.
        """
        optimizer = self.hparams.optimizer(self.learner.parameters(), lr=self.hparams.config['lr'], **self.hparams.opt_kwargs)
        return optimizer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward propagation of inputs through the network.

        Args:
            x: Tensor of shape [batch_size, num_features] with data inputs.

        Returns:
            Tensor of shape [batch_size] with predicted logits.
        """
        return self.learner(x)