From 078361c28d7c768c63065fb5967088821b45712e Mon Sep 17 00:00:00 2001 From: Kilian Date: Wed, 13 Dec 2023 15:38:15 -0500 Subject: [PATCH] add warning parallel training print --- examples/cifar10/train_cifar10.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/cifar10/train_cifar10.py b/examples/cifar10/train_cifar10.py index 921aac8..24f88e7 100644 --- a/examples/cifar10/train_cifar10.py +++ b/examples/cifar10/train_cifar10.py @@ -105,6 +105,9 @@ def train(argv): optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr) sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr) if FLAGS.parallel: + print( + "Warning: parallel training is performing slighlty worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory." + ) net_model = torch.nn.DataParallel(net_model) ema_model = torch.nn.DataParallel(ema_model)