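"""Train binary classifiers on WILDS datasets (celebA, waterbirds) with the ERM, IRM, or
ILC objective, logging per-group losses and accuracies for TensorBoard.

Intended to be launched as a SLURM array job: the positional argument selects one entry
from the experiment grid defined in main(); pass -1 to print the grid instead.
"""
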
import argparse
import itertools
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm
from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader
from wilds.common.grouper import CombinatorialGrouper
from ilc import get_grads as ilc_grads
from ilc import get_train_loader
import utils
def setup_dataset(dataset: str, groups, root_dir: str):
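    """Load a WILDS dataset and return its train subset, val subset, and a grouper.

    Both subsets are resized to 224x224 and converted to tensors; the grouper maps
    example metadata to the group IDs used for per-group loss/accuracy reporting.
    """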
dataset = get_dataset(dataset=dataset, root_dir=root_dir)
train = dataset.get_subset(
"train",
transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
)
val = dataset.get_subset(
"val", transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
)
grouper = CombinatorialGrouper(dataset, groups)
return train, val, grouper
def init_model(model_name="resnet18", pretrained=False, model_cache="", device="cuda:0"):
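    """Build a torchvision model and swap in a single-output head for binary classification."""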
if model_cache != "":
torch.hub.set_dir(model_cache)
model = getattr(models, model_name)(pretrained=pretrained)
    # Replace the final layer with a randomly initialized layer with a single output.
    # We can't pass num_classes=1 to the model factory above: pretrained weights are only
    # published for the default output size, so the factory couldn't initialize a
    # single-output head from them.
if model_name == "resnet50":
model.fc = nn.Linear(2048, 1)
elif model_name.startswith("resnet"):
model.fc = nn.Linear(512, 1)
elif model_name.startswith("vit_b"):
model.heads[-1] = nn.Linear(768, 1)
elif model_name.startswith("vit_l"):
model.heads[-1] = nn.Linear(1024, 1)
model.train()
return model.to(device)
def train(config: utils.TrainConfig, model, trainset, valset, recorder, device="cuda:0"):
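    """Run the training loop: set up the optimizer, scheduler, and loaders, resume from the
    latest checkpoint if one exists, then train for config.epochs epochs with validation
    and checkpointing at the end of each epoch."""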
# set up optimization
bce_loss = nn.BCELoss(reduction="none")
optim = torch.optim.SGD(
model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay
)
sched = StepLR(optim, step_size=config.lr_step)
# data loading
trainloader = get_train_loader(
"standard", trainset, batch_size=config.batch_size, drop_last=config.objective == "ilc"
)
valloader = get_eval_loader(
"standard", valset, batch_size=config.batch_size, drop_last=config.objective == "ilc"
)
# look for checkpoints
ep = recorder.latest_checkpoint(model, optim, sched, device)
if ep == 0:
print("starting with fresh model")
# run evaluation once at the beginning
validation(model, valloader, 0, bce_loss, recorder, device)
else:
print("starting from checkpointed epoch", ep)
# train
model.train()
t = tqdm(initial=ep * len(trainloader), total=config.epochs * len(trainloader), mininterval=10)
while ep < config.epochs:
print("starting epoch", ep + 1)
for i, (x, y, metadata) in enumerate(trainloader):
it = ep * len(trainloader) + i
train_iteration(
config, model, optim, x, y, metadata, ep, it, bce_loss, recorder, device
)
t.update()
sched.step()
ep += 1
validation(model, valloader, ep, bce_loss, recorder, device)
recorder.checkpoint(ep, model, optim, sched)
recorder.close()
t.close()
def train_iteration(config, model, optim, x, y, metadata, epoch, it, loss, recorder, device):
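    """Run one optimization step, accumulating per-group losses and counts for logging
    while optimizing the objective selected by the config."""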
z = recorder.grouper.metadata_to_group(metadata)
optim.zero_grad()
# This is left over from supporting grad accumulation, which I removed because it won't work
# with ILC as we've got it currently, and the GPUs are big enough for the other experiments
# anyway.
actual_batch = len(x) // config.grad_accum_factor
# collect group losses/counts across entire batch
agg_loss = 0.0
agg_penalty = 0.0
losses = [0.0] * recorder.n_groups
counts = [0] * recorder.n_groups
for idx in range(0, len(x), actual_batch):
ex = x[idx : idx + actual_batch].to(device)
why = y[idx : idx + actual_batch].float().to(device)
zee = z[idx : idx + actual_batch].to(device)
logits = torch.sigmoid(model(ex))
# accumulate group losses/counts (this is just straight-up loss, not necessarily the
# training objective)
batch_loss = loss(logits, why.unsqueeze(-1))
agg_loss += torch.sum(batch_loss).item()
for i in range(recorder.n_groups):
mask = zee.eq(i).unsqueeze(-1)
losses[i] += torch.sum(batch_loss * mask).item()
counts[i] += torch.sum(mask).item()
# This is the training objective.
p = penalty(config, loss, logits, why, batch_loss, epoch, optim)
if config.objective != "ilc":
# We don't call backward if it's ILC because the penalty function actually sets
# the gradients itself.
p.backward()
agg_penalty += p.item()
recorder.report_train(it, agg_loss, losses, counts, agg_penalty)
optim.step()
def penalty(
config: utils.TrainConfig, loss: nn.Module, logits, y, batch_loss, epoch: int, optim
) -> torch.Tensor:
"""Returns the training loss for the objective specified by the config."""
if config.objective == "erm":
return torch.sum(batch_loss) / config.batch_size
if config.objective == "irm":
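        # IRMv1 penalty: the squared gradient of the loss w.r.t. a dummy scale on the
        # logits penalizes classifiers that could be improved by rescaling the output,
        # following the reference implementation linked below.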
scale = logits.new_tensor(1.0, requires_grad=True) # place on same device as logits
l = loss(logits * scale, y.unsqueeze(-1))
grad = torch.autograd.grad(l.mean(), [scale], create_graph=True)[0]
p = torch.sum(grad**2)
l = torch.sum(l) / config.batch_size
# see https://github.com/facebookresearch/InvariantRiskMinimization/blob/main/code/colored_mnist/main.py#L145
p_weight = config.irm_weight if epoch >= config.irm_anneal else 1.0
l += p_weight * p
if p_weight > 1.0:
# keep gradients in a reasonable range
l /= p_weight
return l
if config.objective == "ilc":
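        # ILC: treat each example in the batch as its own environment (batch_size=1,
        # n_agreement_envs=len(y)) and apply the AND-mask to the gradients. ilc_grads
        # sets the parameter gradients itself, which is why train_iteration skips
        # backward() for this objective.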
l, _ = ilc_grads(
agreement_threshold=config.ilc_agreement_threshold,
batch_size=1,
loss_fn=loss,
n_agreement_envs=len(y),
params=optim.param_groups[0]["params"],
output=logits,
target=y,
method="and_mask",
scale_grad_inverse_sparsity=1.0,
)
return l
raise NotImplementedError(f"unrecognized penalty type {config.objective}")
def validation(model, valloader, epoch, loss, recorder, device="cuda:0"):
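    """Evaluate per-group validation loss and accuracy, report them to the recorder, and
    return the model to train mode."""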
print("running validation for epoch", epoch)
with torch.no_grad():
model.eval()
losses = [0.0] * recorder.n_groups
accs = [0.0] * recorder.n_groups
for x, y, metadata in valloader:
x = x.to(device)
y = y.to(device)
z = recorder.grouper.metadata_to_group(metadata).to(device)
logits = torch.sigmoid(model(x))
batch_loss = loss(logits, y.unsqueeze(-1).float())
batch_preds = logits >= 0.5
acc = batch_preds.squeeze(-1) == y
for i in range(recorder.n_groups):
mask = z.eq(i)
losses[i] += torch.sum(batch_loss.squeeze(-1) * mask).detach().item()
accs[i] += torch.logical_and(acc, mask).sum().detach().item()
recorder.report_valid(epoch, losses, accs)
model.train()
def main(exp, wilds_dir, model_cache, tensorboard_dir, checkpoint_dir):
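    """Define the experiment grid, print it when exp == -1, otherwise run the experiment at index exp."""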
# This part of the file is meant to be edited to add the sorts of experiments you'd like to run.
# I'm not a fan of computing the experiment parameters from the SLURM array ID in bash. So we'll
# define a list of experiments here and let the SLURM array ID index them.
exps = list(
itertools.product(
["celebA", "waterbirds"],
["erm", "irm", ("ilc", 0.1), ("ilc", 0.25), ("ilc", 0.5)],
["resnet18", "resnet34", "resnet50", "vit_b_32", "vit_l_32"],
[True, False], # pretrained
)
)
# unpack ILC-specific param so we don't duplicate non-ILC exps
for i, e in enumerate(exps):
if e[1][0] == "ilc":
e = e[0], "ilc", e[2], e[3], e[1][1]
else:
e = *e, 0.0
exps[i] = e
if exp == -1:
for i, e in enumerate(exps):
print(f"{i}: {e}")
return
# select the experiment according to the SLURM array ID
exps = exps[exp : exp + 1]
print(f"running experiment {exp}: {exps[0]}")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("training with", device)
for ds, obj, m, pre, agree in exps:
config = utils.TrainConfig(
dataset=ds,
model_name=m,
pretrained=pre,
epochs=30 if ds == "celebA" else 100,
objective=obj,
ilc_agreement_threshold=agree,
lr=0.0005 if pre else 0.01,
)
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
trainset, valset, grouper = setup_dataset(
config.dataset, config.dataset_groups(), wilds_dir
)
model = init_model(config.model_name, config.pretrained, model_cache, device)
recorder = utils.Recorder(config, valset, grouper, tensorboard_dir, checkpoint_dir)
train(config, model, trainset, valset, recorder, device)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"exp",
metavar="EXP",
type=int,
default=-1,
nargs="?",
help="experiment number to run. If -1, instead prints the experiment combinations.",
)
parser.add_argument(
"--wilds", type=str, default="./data/", help="path where WILDS data can be stored"
)
parser.add_argument(
"--model-cache", type=str, default="", help="override the PyTorch model cache"
)
parser.add_argument(
"--tensorboard", type=str, default="./runs/", help="path where Tensorboard output is stored"
)
parser.add_argument(
"--checkpoints",
type=str,
default="./checkpoints/",
help="path where checkpointed models are stored",
)
args = parser.parse_args()
main(
args.exp,
*(
os.path.expanduser(p)
for p in [args.wilds, args.model_cache, args.tensorboard, args.checkpoints]
),
)