diff --git a/main.py b/main.py index 7eeb7fd..a7b5d2f 100644 --- a/main.py +++ b/main.py @@ -233,12 +233,11 @@ def batchify(data, bsz): # if args.cuda and (not args.single) and (torch.cuda.device_count() > 1): # model.module.rnn.flatten_parameters() # else: + if args.cuda and (not args.single) and (torch.cuda.device_count() > 1): + # Scatters minibatches (in dim=1) across available GPUs + model = nn.DataParallel(model, dim=1) if isinstance(model, torch.nn.DataParallel): model = model.module - elif args.cuda: - if (not args.single) and (torch.cuda.device_count() > 1): - # Scatters minibatches (in dim=1) across available GPUs - model = nn.DataParallel(model, dim=1) model.rnn.flatten_parameters() criterion = nn.CrossEntropyLoss()