-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_crosspoint.py
314 lines (245 loc) · 12.5 KB
/
train_crosspoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
from __future__ import print_function
import os
import datetime
import torch
import numpy as np
import wandb
from lightly.loss.ntx_ent_loss import NTXentLoss
from sklearn.svm import SVC
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as transforms
from torchvision.models import resnet50
from torch.utils.data import DataLoader
# for distributed training
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from datasets.data import ShapeNetRender, ModelNet40SVM
from models.dgcnn import DGCNN, ResNet, DGCNN_partseg
from util import IOStream, AverageMeter
from parser import args
def _init_():
    """Prepare the experiment output directory and log in to Weights & Biases.

    Creates ``checkpoints/<exp_name>/models`` (and any missing parents),
    points wandb at ``args.wb_url`` and authenticates with ``args.wb_key``.
    Run once in the launching process, before the DDP workers are spawned.
    """
    # A single makedirs(exist_ok=True) creates every missing parent and,
    # unlike separate exists()/makedirs() checks, cannot fail if the directory
    # appears between the check and the creation.
    os.makedirs(os.path.join('checkpoints', args.exp_name, 'models'), exist_ok=True)
    # Child processes inherit os.environ, so spawned workers see this URL too.
    os.environ["WANDB_BASE_URL"] = args.wb_url
    wandb.login(key=args.wb_key)
def setup(rank):
    """Bind this worker to its GPU and join the default process group.

    Args:
        rank: index of this worker process; also used as its CUDA device index.
    """
    # Rendezvous information for init_process_group's default env:// method.
    # os.environ only accepts str values, so coerce in case the parser
    # declares these as non-strings (e.g. an int port).
    os.environ['MASTER_ADDR'] = str(args.master_addr)
    os.environ['MASTER_PORT'] = str(args.master_port)
    if torch.cuda.is_available():
        # Object collectives such as dist.all_gather_object() on the NCCL
        # backend stage data on torch.cuda.current_device(); without this call
        # every rank would default to GPU 0.
        torch.cuda.set_device(rank)
    dist.init_process_group(args.backend, rank=rank, world_size=args.world_size)
def cleanup():
    """Destroy the default process group joined in ``setup``."""
    dist.destroy_process_group()
def _load_checkpoint(ddp_model, path, map_location):
    """Load the state dict stored at *path* into a DDP-wrapped model.

    This script saves ``ddp.module.state_dict()`` (keys without a ``module.``
    prefix), while ``ddp.state_dict()`` yields prefixed keys. Pick whichever
    target (wrapper or wrapped module) matches the checkpoint's key layout,
    so both kinds of checkpoint load cleanly.
    """
    state = torch.load(path, map_location=map_location)
    if all(key.startswith('module.') for key in state):
        ddp_model.load_state_dict(state)
    else:
        ddp_model.module.load_state_dict(state)


def _gather_mean(value, world_size):
    """Return the mean of a Python scalar gathered from every rank.

    ``dist.all_gather_object(out, value)`` fills ``out`` (whose length must
    equal the world size) with each rank's ``value``; e.g. with 6 GPUs it may
    produce ``[2.11, 4.14, 1.24, 4.48, 2.15, 3.13]``, and every rank receives
    the same list.
    """
    gathered = [None] * world_size
    dist.all_gather_object(gathered, value)
    return np.mean(gathered)


def _extract_features(model, loader, device):
    """Embed every sample in *loader* and return ``(features, labels)`` arrays.

    The embedding is the second output of the point model (``model(x)[1]``);
    the label is the first entry of each sample's label row.
    """
    feats, labels = [], []
    with torch.no_grad():
        for data, label in loader:
            data = data.permute(0, 2, 1).to(device)
            feats.extend(model(data)[1].detach().cpu().numpy())
            labels.extend(row[0] for row in label.numpy().tolist())
    return np.array(feats), np.array(labels)


def _save_models(point_module, img_module, point_name, img_name):
    """Write both encoders' state dicts into ``checkpoints/<exp_name>/models/``."""
    model_dir = f'checkpoints/{args.exp_name}/models/'
    torch.save(point_module.state_dict(), os.path.join(model_dir, point_name))
    torch.save(img_module.state_dict(), os.path.join(model_dir, img_name))


def train(rank):
    """Per-process CrossPoint training loop, launched through ``mp.spawn``.

    Each process trains one DDP replica on device ``rank`` with an intra-modal
    loss (two point-cloud views) plus a cross-modal loss (point vs image
    features), then evaluates the point encoder with a linear SVM on
    ModelNet40 after every epoch. Only rank 0 logs to wandb and writes
    checkpoints.

    Args:
        rank: process index supplied by ``mp.spawn``; also the CUDA device.
    """
    if rank == 0:
        wandb.init(project="CrossPoint", name=args.exp_name)
    setup(rank)
    # mp.spawn starts fresh interpreters, so the seed set in the launcher never
    # reaches the workers; seed here, offset by rank so the random streams of
    # different ranks are not identical.
    torch.manual_seed(args.seed + rank)
    torch.cuda.manual_seed(args.seed + rank)
    io = IOStream('checkpoints/' + args.exp_name + '/run.log', rank=rank)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_set = ShapeNetRender(transform, n_imgs=2)
    train_sampler = DistributedSampler(train_set, num_replicas=args.world_size, rank=rank)
    # args.batch_size is the global batch; each rank loads its share of it.
    samples_per_gpu = args.batch_size // args.world_size
    train_loader = DataLoader(train_set,
                              sampler=train_sampler,
                              batch_size=samples_per_gpu,
                              shuffle=False,  # ordering is owned by the sampler
                              num_workers=0,
                              pin_memory=True)
    # The SVM evaluation sets are the same every epoch; build them once instead
    # of reloading them after each epoch.
    train_val_loader = DataLoader(ModelNet40SVM(partition='train', num_points=1024),
                                  batch_size=args.test_batch_size, shuffle=True)
    test_val_loader = DataLoader(ModelNet40SVM(partition='test', num_points=1024),
                                 batch_size=args.test_batch_size, shuffle=True)

    # DGCNN and DGCNN_partseg read args.rank to choose the device on which
    # get_graph_feature() runs.
    args.rank = rank
    if args.model == 'dgcnn':
        point_model = DGCNN(args).to(rank)
    elif args.model == 'dgcnn_seg':
        point_model = DGCNN_partseg(args).to(rank)
    else:
        raise Exception("Not implemented")
    img_model = ResNet(resnet50(), feat_dim=2048).to(rank)
    point_model_ddp = DDP(point_model, device_ids=[rank], find_unused_parameters=True)
    img_model_ddp = DDP(img_model, device_ids=[rank], find_unused_parameters=True)

    if args.resume:
        map_location = torch.device('cuda:%d' % rank)
        _load_checkpoint(point_model_ddp, args.model_path, map_location)
        _load_checkpoint(img_model_ddp, args.img_model_path, map_location)
        io.cprint("Model Loaded !!")

    # A single optimizer updates both encoders.
    parameters = list(point_model_ddp.parameters()) + list(img_model_ddp.parameters())
    if args.use_sgd:
        io.cprint("Use SGD")
        opt = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=1e-6)
    else:
        io.cprint("Use Adam")
        opt = optim.Adam(parameters, lr=args.lr, weight_decay=1e-6)
    lr_scheduler = CosineAnnealingLR(opt, T_max=args.epochs, eta_min=0, last_epoch=-1)
    criterion = NTXentLoss(temperature=0.1).to(rank)
    best_acc = 0

    for epoch in range(args.epochs):
        # ---------------- Train ----------------
        train_losses = AverageMeter()
        train_imid_losses = AverageMeter()
        train_cmid_losses = AverageMeter()
        # DistributedSampler only reshuffles across epochs when told the epoch.
        train_sampler.set_epoch(epoch)
        # Put the DDP wrappers (not just the inner modules) back in training
        # mode; the evaluation below calls point_model_ddp.eval(), and
        # Module.train() on the wrapper also reaches the wrapped module.
        point_model_ddp.train()
        img_model_ddp.train()
        io.cprint(f'Start training epoch: ({epoch}/{args.epochs})')
        for i, ((data_t1, data_t2), imgs) in enumerate(train_loader):
            data_t1, data_t2, imgs = data_t1.to(rank), data_t2.to(rank), imgs.to(rank)
            batch_size = data_t1.size(0)
            opt.zero_grad()
            # Both point-cloud views go through the point encoder in one pass.
            data = torch.cat((data_t1, data_t2)).transpose(2, 1).contiguous()
            point_feats = point_model_ddp(data)[0]
            img_feats = img_model_ddp(imgs)
            point_t1_feats = point_feats[:batch_size, :]
            point_t2_feats = point_feats[batch_size:, :]
            # Intra-modal term: view 1 vs view 2 of the same point cloud.
            loss_imid = criterion(point_t1_feats, point_t2_feats)
            # Cross-modal term: mean of the two point views vs the image features.
            point_proto = torch.stack([point_t1_feats, point_t2_feats]).mean(dim=0)
            loss_cmid = criterion(point_proto, img_feats)
            total_loss = loss_imid + loss_cmid
            total_loss.backward()
            opt.step()
            train_losses.update(total_loss.item(), batch_size)
            train_imid_losses.update(loss_imid.item(), batch_size)
            train_cmid_losses.update(loss_cmid.item(), batch_size)
            if i % args.print_freq == 0:
                now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                outstr = ('[%s] Epoch (%d), Batch(%d/%d), loss: %.6f, imid loss: %.6f, cmid loss: %.6f '
                          % (now, epoch, i, len(train_loader), train_losses.avg,
                             train_imid_losses.avg, train_cmid_losses.avg))
                io.cprint(outstr)
        # Since PyTorch 1.1.0, lr_scheduler.step() must come after optimizer.step().
        lr_scheduler.step()

        # Average each rank's running means so every rank reports the same values.
        train_loss_avg = _gather_mean(train_losses.avg, args.world_size)
        train_imid_loss_avg = _gather_mean(train_imid_losses.avg, args.world_size)
        train_cmid_loss_avg = _gather_mean(train_cmid_losses.avg, args.world_size)
        outstr = 'Train %d, loss: %.6f, imid loss: %.6f, cmid loss: %.6f' % (
            epoch, train_loss_avg, train_imid_loss_avg, train_cmid_loss_avg)
        io.cprint(outstr)

        # ---------------- Linear SVM evaluation ----------------
        point_model_ddp.eval()
        feats_train, labels_train = _extract_features(point_model_ddp, train_val_loader, rank)
        feats_test, labels_test = _extract_features(point_model_ddp, test_val_loader, rank)
        io.cprint('Training SVM ...')
        model_tl = SVC(C=0.1, kernel='linear')
        model_tl.fit(feats_train, labels_train)
        io.cprint('Testing SVM ...')
        test_accuracy = model_tl.score(feats_test, labels_test)
        test_accuracy_avg = _gather_mean(test_accuracy, args.world_size)
        io.cprint(f"Overall Linear Accuracy : {test_accuracy_avg}")

        if rank == 0:
            wandb.log({
                'Train Loss': train_loss_avg,
                'Train IMID Loss': train_imid_loss_avg,
                'Train CMID Loss': train_cmid_loss_avg,
                'Overall Linear Accuracy': test_accuracy_avg,
            })
            # Save .module state dicts so checkpoints carry no DDP 'module.' prefix; see
            # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/9
            if test_accuracy_avg > best_acc:
                best_acc = test_accuracy_avg
                io.cprint('==> Saving Best Model...')
                _save_models(point_model_ddp.module, img_model_ddp.module,
                             'best_model.pth', 'img_model_best.pth')
            if epoch % args.save_freq == 0:
                io.cprint(f'==> Saving {epoch}_th model...')
                _save_models(point_model_ddp.module, img_model_ddp.module,
                             f'ckpt_epoch_{epoch}.pth', f'img_model_{epoch}.pth')

    if rank == 0:
        io.cprint('==> Saving Last Model...')
        _save_models(point_model_ddp.module, img_model_ddp.module,
                     'ckpt_epoch_last.pth', 'img_model_last.pth')
        # wandb.finish() must be called explicitly in multi-process training,
        # otherwise wandb hangs in this process.
        wandb.finish()
    io.close()
    cleanup()
if __name__ == "__main__":
    # Launcher: prepare outputs, validate the GPU setup, then spawn one
    # training process per rank.
    _init_()
    io = IOStream('checkpoints/' + args.exp_name + '/run.log', rank=0)
    io.cprint(str(args))
    # DDP training here requires CUDA and more than one visible GPU.
    args.cuda = (not args.no_cuda
                 and torch.cuda.is_available()
                 and torch.cuda.device_count() > 1)
    torch.manual_seed(args.seed)
    if not args.cuda:
        io.cprint('CUDA is unavailable! Exit')
        io.close()
    elif args.world_size > torch.cuda.device_count():
        # Each worker places its models on device index == rank, so more
        # workers than GPUs would crash inside the children with an invalid
        # device ordinal; refuse up front instead.
        io.cprint('world_size (%d) exceeds the %d visible GPUs! Exit'
                  % (args.world_size, torch.cuda.device_count()))
        io.close()
    else:
        io.cprint('CUDA is available! Using %d GPUs for DDP training' % args.world_size)
        io.close()
        torch.cuda.manual_seed(args.seed)
        mp.spawn(train, nprocs=args.world_size)