diff --git a/examples/cifar10/train_cifar10.py b/examples/cifar10/train_cifar10.py
index 921aac8..24f88e7 100644
--- a/examples/cifar10/train_cifar10.py
+++ b/examples/cifar10/train_cifar10.py
@@ -105,6 +105,9 @@ def train(argv):
     optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
     sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
     if FLAGS.parallel:
+        print(
+            "Warning: parallel training performs slightly worse than single-GPU training due to statistics computation in DataParallel. We recommend training on a single GPU, which requires around 8 GB of GPU memory."
+        )
         net_model = torch.nn.DataParallel(net_model)
         ema_model = torch.nn.DataParallel(ema_model)