linear_eval.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from arguments import get_args
from augmentations import get_aug
from models import get_backbone
from tools import AverageMeter
from datasets import get_dataset
from optimizers import get_optimizer, LR_Scheduler


def main(args):
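    # Linear evaluation: the pretrained backbone is frozen and only a linear
    # classifier is trained on top of its features. The train split uses the
    # classifier-training augmentation (train_classifier=True); the test split
    # uses plain evaluation transforms.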
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(
            transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
            train=True,
            **args.dataset_kwargs
        ),
        batch_size=args.eval.batch_size,
        shuffle=True,
        **args.dataloader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(
            transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
            train=False,
            **args.dataset_kwargs
        ),
        batch_size=args.eval.batch_size,
        shuffle=False,
        **args.dataloader_kwargs
    )
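
    # Frozen backbone plus a single linear layer trained from scratch.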
    model = get_backbone(args.model.backbone)
    # NOTE: the number of output classes is hardcoded; it must match the
    # number of classes in the evaluation dataset.
    classifier = nn.Linear(in_features=model.output_dim, out_features=16, bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    # Keep only the backbone weights and strip the 'backbone.' prefix
    # (len('backbone.') == 9) so the keys match the bare backbone module.
    msg = model.load_state_dict({k[9:]: v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True)
    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)
    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
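
    # Learning rates follow the linear scaling rule: lr = base_lr * batch_size / 256.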
    # define optimizer
    optimizer = get_optimizer(
        args.eval.optimizer.name, classifier,
        lr=args.eval.base_lr * args.eval.batch_size / 256,
        momentum=args.eval.optimizer.momentum,
        weight_decay=args.eval.optimizer.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.eval.warmup_epochs, args.eval.warmup_lr * args.eval.batch_size / 256,
        args.eval.num_epochs, args.eval.base_lr * args.eval.batch_size / 256,
        args.eval.final_lr * args.eval.batch_size / 256,
        len(train_loader),
    )
    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.eval.num_epochs), desc='Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True)

        for idx, (images, labels) in enumerate(local_progress):
            classifier.zero_grad()
            # Features are extracted without gradients: the backbone stays frozen,
            # and only the classifier forward pass below builds a graph.
            with torch.no_grad():
                feature = model(images.to(args.device))
            preds = classifier(feature)
            loss = F.cross_entropy(preds, labels.to(args.device))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({'lr': lr, 'loss': loss_meter.val, 'loss_avg': loss_meter.avg})
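
        # Evaluate on the held-out test split after every epoch.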
        classifier.eval()
        correct, total = 0, 0
        acc_meter.reset()
        for idx, (images, labels) in enumerate(test_loader):
            with torch.no_grad():
                feature = model(images.to(args.device))
                preds = classifier(feature).argmax(dim=1)
                # Accumulate across batches so the final partial batch is
                # weighted correctly.
                correct += (preds == labels.to(args.device)).sum().item()
                total += preds.shape[0]
        acc_meter.update(correct / total)
    print(f'Accuracy = {acc_meter.avg*100:.2f}')


if __name__ == "__main__":
    main(args=get_args())
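
# Note: `get_args()` (arguments.py) supplies the config object used throughout;
# `args.eval_from` must point to a checkpoint whose 'state_dict' stores backbone
# weights under a 'backbone.' prefix, presumably as written by the pretraining
# script in this repo.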