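"""Sweep over co-training eras: for each era, train a model from scratch on
the indices (and pseudo-labels) dumped by a previous co-training run, and log
the final validation accuracy per era.

Example invocation (flag values are illustrative, not canonical):

    python sweep_final_models.py --dataset cifar --data-path /data/cifar \
        --indices-dir /runs/cotrain --out-dir /runs/sweep --arch resnet18
"""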

import argparse
import json
import os

import numpy as np
import torch as ch
from cox.store import Store
from tensorboardX import SummaryWriter

from datasets.co_training import GTSingleStep
from datasets.datasets import COPRIOR_DATASETS
from models.model_utils import make_and_restore_model
from utils.logging import log_images_hook


def main(args):
    store = Store(args.out_dir)

    # Record the sweep's hyperparameters in a cox metadata table.
    metadata_keys = {'dataset': str,
                     'indices_dir': str,
                     'use_gt': bool,
                     'dedup': bool,
                     'resample': bool,
                     'fraction': float,
                     'arch': str,
                     'epochs': int,
                     'lr': float,
                     'step_lr': int,
                     'step_lr_gamma': float,
                     'additional_transform': str,
                     'spurious': str}
    store.add_table('metadata', metadata_keys)
    args_dict = args.__dict__
    store['metadata'].append_row({k: args_dict[k] for k in metadata_keys})

    # Per-split loss/accuracy tables, plus a summary table keyed by era.
    for mode in ['train', 'val']:
        store.add_table(mode,
                        {'loss': float,
                         'acc': float,
                         'epoch': int})
    store.add_table('main',
                    {'era': int,
                     'acc': float})

    # MAKE DATASET AND LOADERS
    dataset_name = args.dataset
    data_path = args.data_path
    ds_class = COPRIOR_DATASETS[dataset_name](data_path)
    classes = ds_class.CLASS_NAMES
    train_ds = ds_class.get_dataset('train')
    val_ds = ds_class.get_dataset('test')
    unlabelled_ds = ds_class.get_dataset('unlabeled')

    main_outdir = os.path.join(args.out_dir, 'total')
    os.makedirs(main_outdir, exist_ok=True)
    main_writer = SummaryWriter(main_outdir)
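
    # Sweep over a fixed set of eras: train a fresh model per era on that
    # era's selected indices and record its final validation accuracy.
    # (args.eras is parsed but not used here.)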
    model = None  # returned as-is if no indices files are found
    for era in [1, 2, 5, 10, 15, 20]:
        print("ERA", era)
        indices_path = os.path.join(args.indices_dir, f'indices_era_{era}.pt')
        labels_path = os.path.join(args.indices_dir, f'ys_era_{era}.pt')
        if not os.path.exists(indices_path):
            break
        indices = ch.load(indices_path)
        if args.use_gt:
            # Train on ground-truth labels for the selected indices.
            labels = None
            if args.dedup:
                indices = np.unique(indices)
            if args.resample:  # TODO not fully kosher
                # Control: replace the selected indices with a uniform random
                # sample of the same size.
                indices = np.random.choice(np.arange(max(indices) + 1),
                                           len(indices),
                                           replace=False)
        else:
            # Train on the pseudo-labels saved alongside the indices.
            labels = ch.load(labels_path)

        out_dir = os.path.join(args.out_dir, f'era_{era}')
        os.makedirs(out_dir, exist_ok=True)
        model, model_args, checkpoint = make_and_restore_model(
            arch_name=args.arch, ds_class=ds_class,
            resume_path=args.resume_path, train_args=args,
            additional_transform=args.additional_transform,
            out_dir=out_dir)
        writer = SummaryWriter(out_dir)
        gt_single_step = GTSingleStep(train_dataset=train_ds,
                                      unlabelled_dataset=unlabelled_ds,
                                      val_dataset=val_ds, store=store,
                                      writer=writer, log_hook=log_images_hook,
                                      start_fraction=args.fraction, indices=indices,
                                      out_dir=out_dir,
                                      labels=labels, spurious=args.spurious)
        _, prec1 = gt_single_step.run(epochs=args.epochs, model=model, model_args=model_args,
                                      out_dir=out_dir, val_iters=25, checkpoint_iters=50)
        main_writer.add_scalar('total_val_prec1', prec1, era)
        store['main'].append_row({'era': era, 'acc': prec1})
    return model


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, help='name of dataset')
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--out-dir', type=str, help='path to dump output')
    parser.add_argument('--indices-dir', type=str, default=None,
                        help='directory containing per-era indices_era_*.pt and ys_era_*.pt files')
    parser.add_argument('--use-gt', action='store_true',
                        help='train on ground-truth labels instead of saved pseudo-labels')
    parser.add_argument('--dedup', action='store_true',
                        help='deduplicate the loaded indices (only applies with --use-gt)')
    parser.add_argument('--resample', action='store_true',
                        help='replace indices with a random sample of equal size (only applies with --use-gt)')
    parser.add_argument('--fraction', type=float, default=0.05,
                        help='starting fraction of labelled data')
    parser.add_argument('--arch', type=str, help='name of model architecture')
    parser.add_argument('--resume_path', type=str, default=None, help='path to load a previous checkpoint')
    parser.add_argument('--epochs', type=int, default=400, help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--step_lr', type=int, default=50, help='epochs between LR drops')
    parser.add_argument('--step_lr_gamma', type=float, default=0.1, help='LR drop multiplier')
    parser.add_argument('--additional-transform', type=str, default='NONE', help='type of additional transform')
    parser.add_argument('--eras', type=int, default=10,
                        help='number of eras (currently unused; main() sweeps a fixed list)')
    parser.add_argument('--spurious', type=str, default=None,
                        help='add a spurious correlation to the training dataset')
    return parser.parse_args()
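

# Entry point. CLI arguments may be overridden by a job_parameters.json file
# in the working directory (e.g. when launched by a job scheduler); every key
# in that file must name an existing argument. An illustrative override:
#   {"dataset": "cifar", "arch": "resnet18", "epochs": 200}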
if __name__ == "__main__":
    args = parse_args()
    if os.path.exists('job_parameters.json'):
        with open('job_parameters.json') as f:
            job_params = json.load(f)
        for k, v in job_params.items():
            assert hasattr(args, k), f'unknown argument in job_parameters.json: {k}'
            setattr(args, k, v)
    os.makedirs(args.out_dir, exist_ok=True)
    print(args.__dict__)
    # Snapshot the final configuration next to the results.
    with open(os.path.join(args.out_dir, 'env.json'), 'w') as f:
        json.dump(args.__dict__, f)
    main(args)