import math
import time

import numpy as np
import torch


class ModelWrapper(torch.nn.Module):
    """Wraps an image encoder exposing `encode_image` (e.g. a CLIP visual backbone) with a linear classification head."""

    def __init__(self, model, feature_dim, num_classes, normalize=False, initial_weights=None):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.feature_dim = feature_dim
        self.classification_head = torch.nn.Linear(feature_dim, num_classes)
        self.normalize = normalize
        if initial_weights is None:
            initial_weights = torch.zeros_like(self.classification_head.weight)
            torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5))
        self.classification_head.weight = torch.nn.Parameter(initial_weights.clone())
        self.classification_head.bias = torch.nn.Parameter(
            torch.zeros_like(self.classification_head.bias))

        # Note: modified. Get rid of the language part.
        if hasattr(self.model, 'transformer'):
            delattr(self.model, 'transformer')

    def change_classifier(self, num_classes, initial_weights=None):
        self.classification_head = torch.nn.Linear(self.feature_dim, num_classes)
        if initial_weights is None:
            initial_weights = torch.zeros_like(self.classification_head.weight)
            torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5))
        self.classification_head.weight = torch.nn.Parameter(initial_weights.clone())
        self.classification_head.bias = torch.nn.Parameter(
            torch.zeros_like(self.classification_head.bias))

    def forward(self, images, return_features=False):
        features = self.model.encode_image(images)
        if self.normalize:
            features = features / features.norm(dim=-1, keepdim=True)
        logits = self.classification_head(features)
        if return_features:
            return logits, features
        return logits


def get_model_from_sd(state_dict, base_model):
    """Build a ModelWrapper around base_model, load a (possibly DataParallel-prefixed) state dict, and return it wrapped in DataParallel."""
    if 'classification_head.weight' not in state_dict:
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove the `module.` prefix added by DataParallel
            new_state_dict[name] = v
        state_dict = new_state_dict
    feature_dim = state_dict['classification_head.weight'].shape[1]
    num_classes = state_dict['classification_head.weight'].shape[0]
    model = ModelWrapper(base_model, feature_dim, num_classes, normalize=True)
    for p in model.parameters():
        p.data = p.data.float()
    model.load_state_dict(state_dict)
    model = model.cuda()
    devices = list(range(torch.cuda.device_count()))
    return torch.nn.DataParallel(model, device_ids=devices)


def get_model_from_sd_modified(state_dict, base_model, num_classes_, initial_weights=None):
    """Same as get_model_from_sd, but swaps in a fresh classification head with num_classes_ outputs."""
    if 'classification_head.weight' not in state_dict:
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove the `module.` prefix added by DataParallel
            new_state_dict[name] = v
        state_dict = new_state_dict
    feature_dim = state_dict['classification_head.weight'].shape[1]
    num_classes = state_dict['classification_head.weight'].shape[0]
    model = ModelWrapper(base_model, feature_dim, num_classes, normalize=True)
    for p in model.parameters():
        p.data = p.data.float()
    model.load_state_dict(state_dict)
    model.change_classifier(num_classes_, initial_weights)
    model = model.cuda()
    devices = list(range(torch.cuda.device_count()))
    return torch.nn.DataParallel(model, device_ids=devices)
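
# Usage sketch (illustrative, not part of the original file). It assumes the OpenAI
# `clip` package is installed and that `checkpoint.pt` is a fine-tuned checkpoint on
# disk; the model name, path, and class count below are placeholders.
#
#   import clip
#   base_model, _ = clip.load('ViT-B/32', 'cpu', jit=False)
#   state_dict = torch.load('checkpoint.pt', map_location='cpu')
#   model = get_model_from_sd(state_dict, base_model)                              # keep original head
#   model = get_model_from_sd_modified(state_dict, base_model, num_classes_=10)    # fresh 10-way head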


def maybe_dictionarize_batch(batch):
    if isinstance(batch, dict):
        return batch
    if len(batch) == 2:
        return {'images': batch[0], 'labels': batch[1]}
    elif len(batch) == 3:
        return {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    else:
        raise ValueError(f'Unexpected number of elements: {len(batch)}')


def test_model_on_dataset(model, dataset):
    """Evaluate `model` on `dataset` and return top-1 accuracy."""
    model.eval()
    device = 'cuda'
    with torch.no_grad():
        top1, correct, n = 0., 0., 0.
        end = time.time()
        loader = dataset.test_loader
        if type(dataset).__name__ == 'ImageNet2p':
            loader = dataset.train_loader
            # assert to make sure the imagenet held-out minival logic is consistent across machines.
            # tested on a few machines but if this fails for you please submit an issue and we will resolve.
            assert dataset.train_dataset.__getitem__(dataset.sampler.indices[1000])['image_paths'].endswith('n01675722_4108.JPEG')

        for i, batch in enumerate(loader):
            batch = maybe_dictionarize_batch(batch)
            inputs, labels = batch['images'].cuda(), batch['labels'].cuda()
            data_time = time.time() - end
            y = labels

            # default to None so datasets without per-sample paths still work below
            image_paths = batch['image_paths'] if 'image_paths' in batch else None

            logits = model(inputs)

            projection_fn = getattr(dataset, 'project_logits', None)
            if projection_fn is not None:
                logits = projection_fn(logits, device)

            if hasattr(dataset, 'project_labels'):
                y = dataset.project_labels(y, device)
            if isinstance(logits, list):
                logits = logits[0]
            pred = logits.argmax(dim=1, keepdim=True).to(device)
            if hasattr(dataset, 'accuracy'):
                acc1, num_total = dataset.accuracy(logits, y, image_paths, None)
                correct += acc1
                n += num_total
            else:
                correct += pred.eq(y.view_as(pred)).sum().item()
                n += y.size(0)

            batch_time = time.time() - end
            end = time.time()
            if i % 20 == 0:
                percent_complete = 100.0 * i / len(loader)
                print(
                    f"[{percent_complete:.0f}% {i}/{len(loader)}]\t"
                    f"Acc: {100 * (correct/n):.2f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
                )

        top1 = correct / n

    return top1
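
# Usage sketch (illustrative, not part of the original file). `dataset` is assumed to be
# one of this repo's dataset wrappers, i.e. an object exposing `.test_loader` (and, for
# the ImageNet2p held-out split, `.train_loader` plus a fixed sampler):
#
#   accuracy = test_model_on_dataset(model, dataset)
#   print(f'top-1: {accuracy:.4f}')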


def assign_learning_rate(param_group, new_lr):
    param_group["lr"] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Return a callable that sets each param group's lr: linear warmup for `warmup_length` steps, then cosine decay over the remaining steps."""
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    assert len(base_lrs) == len(optimizer.param_groups)

    def _lr_adjuster(step):
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                e = step - warmup_length
                es = steps - warmup_length
                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
            assign_learning_rate(param_group, lr)
    return _lr_adjuster
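
# Usage sketch (illustrative, not part of the original file). The optimizer, base
# learning rate, warmup length, and step counts below are placeholders:
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
#   total_steps = num_epochs * len(train_loader)
#   scheduler = cosine_lr(optimizer, base_lrs=1e-5, warmup_length=500, steps=total_steps)
#   for step in range(total_steps):
#       scheduler(step)   # update every param group's lr before this step
#       ...               # forward / backward
#       optimizer.step()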