Testing out c10d
bearpelican committed Aug 30, 2018
1 parent 2517c60 commit 7b1eb6b
Showing 3 changed files with 75 additions and 7 deletions.
26 changes: 25 additions & 1 deletion pytorch/launch_nv.py
@@ -165,7 +165,27 @@
'--no-bn-wd',
'--num-tasks', 4,
'--ami-name', DEFAULT_PYTORCH_SOURCE,
'--env-name', 'pytorch_c10d'
]



# Current best settings 4x p3 - 34.5 minutes
lr = 0.50 * 4 # 4 = num tasks
scale_224 = 224/256
scale_288 = 128/256
c10d = [
'--phases', [
{'ep':0, 'sz':128, 'bs':256, 'trndir':'-sz/160',
'lr':lr*2}
],
'--num-tasks', 4,
'--ami-name', DEFAULT_PYTORCH_SOURCE,
'--env-name', 'pytorch_c10d',
'--c10d',
# '--dist-url', 'file:///home/ubuntu/data/file.sync', # single instances are faster with file sync
# '--dist-url', 'tcp://localhost:6006', # single instances are faster with file sync
# '--dist-url', 'env://',
]

# Current benchmark for 8x p3's - with Aspect Ratio Validation - Works right now for under 30 min (25:45, memory-eight.06, 25:03 sun-eight, 24:31 release-eight.02)
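
For reference, the learning-rate arithmetic implied by the comments in the c10d block above, written out as a sketch (not additional launcher code):

base_lr_per_task = 0.50
num_tasks = 4
lr = base_lr_per_task * num_tasks   # 2.0: the global LR scales linearly with the number of tasks
phase0_lr = lr * 2                  # 4.0: the 128px / bs=256 phase above asks for twice that
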
@@ -356,7 +376,11 @@ def start_training(job, params):
default_params = [
'~/data/imagenet',
'--fp16',
'--logdir', job.logdir
'--logdir', job.logdir,
'--dist-url', f'tcp://{world_0_ip}:6006', # single instances are faster with file sync
# '--dist-url', 'file:///home/ubuntu/data/file.sync', # single instances are faster with file sync
# '--dist-url', 'tcp://localhost:6006', # single instances are faster with file sync
# '--dist-url', 'env://',
]
if world_size > 1: default_params.append('--distributed')
training_args = default_params + params
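
The --dist-url values in both hunks (one active, the rest commented out) are the init_method strings later handed to init_process_group. A minimal sketch of the rendezvous styles they stand for, with placeholder host values:

shared_file = 'file:///home/ubuntu/data/file.sync'  # shared-filesystem rendezvous; per the comments,
                                                    # fastest when all workers share one instance
local_tcp   = 'tcp://localhost:6006'                # TCP store, single machine
multi_node  = 'tcp://10.0.0.1:6006'                 # placeholder for world_0_ip: every task dials task 0
from_env    = 'env://'                              # MASTER_ADDR/MASTER_PORT and rank info come from the environment
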
12 changes: 12 additions & 0 deletions pytorch/training/dist_utils.py
@@ -1,6 +1,7 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import os
from torch.nn.parallel import distributed_c10d

class DDP(DistributedDataParallel):
# Distributed wrapper. Supports asynchronous evaluation and model saving
@@ -15,7 +16,18 @@ def load_state_dict(self, *args, **kwargs):
def state_dict(self, *args, **kwargs):
return self.module.state_dict(*args, **kwargs)

class DDPC10d(distributed_c10d._DistributedDataParallelC10d):
# Distributed wrapper. Supports asynchronous evaluation and model saving
def forward(self, *args, **kwargs):
# DDP has a sync point on forward. No need to do this for eval. This allows us to have different batch sizes
if self.training: return super().forward(*args, **kwargs)
else: return self.module(*args, **kwargs)

def load_state_dict(self, *args, **kwargs):
self.module.load_state_dict(*args, **kwargs)

def state_dict(self, *args, **kwargs):
return self.module.state_dict(*args, **kwargs)


def reduce_tensor(tensor): return sum_tensor(tensor)/env_world_size()
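
To make the intent of the wrapper concrete, a minimal usage sketch, assuming a model already wrapped in DDPC10d and placeholder batches x_train / x_val:

model.train()
out = model(x_train)               # training path: the c10d DDP forward runs, so gradient
                                   # synchronization across workers stays active
model.eval()
with torch.no_grad():
    out = model(x_val)             # eval path: self.module is called directly, no sync point,
                                   # so each worker can use a different validation batch size
torch.save(model.state_dict(), 'checkpoint.pth')   # state_dict()/load_state_dict() also unwrap to
                                                    # the bare module, keeping checkpoints DDP-free
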
44 changes: 38 additions & 6 deletions pytorch/training/train_imagenet_nv.py
@@ -10,7 +10,6 @@
from torch.autograd import Variable
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
@@ -49,6 +48,7 @@ def get_parser():
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--fp16', action='store_true', help='Run model fp16 mode. Default True')
parser.add_argument('--c10d', action='store_true', help='Run model c10d mode. Default True')
parser.add_argument('--loss-scale', type=float, default=1024,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
parser.add_argument('--distributed', action='store_true', help='Run distributed training. Default True')
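
The --loss-scale help text is terse; as an illustration of the idea (static loss scaling for fp16), not the repo's fp16 utility code, which also keeps fp32 master copies of the weights:

scale = 1024.0                           # positive power of two, per the help text
loss = criterion(model(images), target)
(loss * scale).backward()                # scale the loss so small fp16 gradients don't underflow to zero
for p in model.parameters():
    if p.grad is not None:
        p.grad.data.div_(scale)          # unscale before the optimizer step
optimizer.step()
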
@@ -75,6 +75,14 @@ def get_parser():
cudnn.benchmark = True
args = get_parser().parse_args()

if args.c10d:
assert(args.distributed)
import torch.distributed.c10d as dist
# from torch.distributed import c10d
from torch.nn.parallel import distributed_c10d
elif args.distributed:
import torch.distributed as dist

# Only want master rank logging to tensorboard
is_master = (not args.distributed) or (dist_utils.env_rank()==0)
is_rank0 = args.local_rank == 0
@@ -91,15 +99,26 @@ def main():
if args.distributed:
log.console('Distributed initializing process group')
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=dist_utils.env_world_size())
dist_url = args.dist_url

if args.c10d and (('file:///' in dist_url) or ('tcp://' in dist_url)):
dist_url = args.dist_url+f'?rank={dist_utils.env_rank()}&world_size={dist_utils.env_world_size()}'
dist.init_process_group(backend=args.dist_backend, init_method=dist_url, world_size=dist_utils.env_world_size())
assert(dist_utils.env_world_size() == dist.get_world_size())
log.console("Distributed: success (%d/%d)"%(args.local_rank, dist.get_world_size()))

log.console('After distributed - test tensor creation works')
tt = torch.tensor([1]).float().cuda()

log.console("Loading model")
model = resnet.resnet50(bn0=args.init_bn0).cuda()
if args.fp16: model = network_to_half(model)
if args.distributed: model = dist_utils.DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

if args.c10d:
model = dist_utils.DDPC10d(model, device_ids=[args.local_rank], output_device=args.local_rank)
# model = distributed_c10d._DistributedDataParallelC10d(model, device_ids=[args.local_rank], output_device=args.local_rank)
c10d_sanity_check()
elif args.distributed: model = dist_utils.DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
best_top5 = 93 # only save models over 93%. Otherwise it stops to save every time

global model_params, master_params
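
A worked example of the rendezvous URL built in the hunk above, with placeholder values standing in for dist_utils.env_rank() and env_world_size():

rank, world_size = 3, 16
dist_url = 'tcp://172.31.0.5:6006'
init_method = dist_url + f'?rank={rank}&world_size={world_size}'
# -> 'tcp://172.31.0.5:6006?rank=3&world_size=16'
# The c10d file:// and tcp:// stores read rank/world_size from this query string,
# whereas env:// takes RANK and WORLD_SIZE from environment variables instead.
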
@@ -132,7 +151,7 @@ def main():

if args.distributed:
log.console('Syncing machines before training')
dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())
sum_tensor(torch.tensor([1.0]).float().cuda())

log.event("~~epoch\thours\ttop1\ttop5\n")
for epoch in range(args.start_epoch, scheduler.tot_epochs):
@@ -152,6 +171,12 @@
if phase: save_checkpoint(epoch, model, best_top5, optimizer, filename=f'sz{phase["bs"]}_checkpoint.path.tar')


def c10d_sanity_check():
log.console('Sanity check to make sure tensor creation works')
tt = torch.tensor([1]).float().cuda()
log.console('Currently deadlock here')
sum_tensor(tt)  # all-reduce the test tensor across workers
log.console('Woot able to reduce tensor')

def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
net_meter = NetworkMeter()
timer = TimeMeter()
@@ -192,7 +217,7 @@ def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
reduced_loss, batch_total = to_python_float(loss.data), to_python_float(input.size(0))
if args.distributed: # Must keep track of global batch size, since not all machines are guaranteed equal batches at the end of an epoch
metrics = torch.tensor([batch_total, reduced_loss, corr1, corr5]).float().cuda()
batch_total, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(metrics).cpu().numpy()
batch_total, reduced_loss, corr1, corr5 = sum_tensor(metrics).cpu().numpy()
reduced_loss = reduced_loss/dist_utils.env_world_size()
top1acc = to_python_float(corr1)*(100.0/batch_total)
top5acc = to_python_float(corr5)*(100.0/batch_total)
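
To see why the true global batch size has to travel with the metrics, a worked example with two hypothetical workers whose final batches differ in size:

batch_total = 256 + 192                       # 448 images processed globally this step
corr1       = 250 + 180                       # 430 top-1 correct globally
top1acc     = corr1 * (100.0 / batch_total)   # ~95.98%, weighted by the true global batch size
# reduced_loss is summed the same way, then divided by dist_utils.env_world_size()
# to recover the mean per-worker loss.
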
@@ -279,7 +304,7 @@ def distributed_predict(input, target, model, criterion):
corr1, corr5 = correct(output.data, target, topk=(1, 5))

metrics = torch.tensor([batch_size, valid_batches, loss, corr1, corr5]).float().cuda()
batch_total, valid_batches, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(metrics).cpu().numpy()
batch_total, valid_batches, reduced_loss, corr1, corr5 = sum_tensor(metrics).cpu().numpy()
reduced_loss = reduced_loss/valid_batches

top1 = corr1*(100.0/batch_total)
@@ -389,6 +414,13 @@ def update_lr(self, epoch, batch_num, batch_tot):
tb.log("sizes/lr", lr)
tb.log("sizes/momentum", args.momentum)


def reduce_tensor(tensor): return sum_tensor(tensor)/dist_utils.env_world_size()
def sum_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.reduce_op.SUM)
return rt

# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if isinstance(t, (float, int)): return t
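
Finally, a minimal sketch of how these helpers behave once the process group is up (local_correct is a placeholder per-worker count):

local_correct = 42                               # placeholder: e.g. correct predictions on this worker
t = torch.tensor([local_correct]).float().cuda()
total = sum_tensor(t)       # all_reduce(SUM): every rank ends up holding the global sum
mean  = reduce_tensor(t)    # the same sum divided by the world size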
