diff --git a/examples/images/cifar10/train_cifar10_ddp.py b/examples/images/cifar10/train_cifar10_ddp.py index ed731ab..8b13849 100644 --- a/examples/images/cifar10/train_cifar10_ddp.py +++ b/examples/images/cifar10/train_cifar10_ddp.py @@ -15,7 +15,7 @@ from torch.utils.data import DistributedSampler from torchdyn.core import NeuralODE from torchvision import datasets, transforms -from utils_cifar import ema, generate_samples, infiniteloop, setup +from utils_cifar import ema, generate_samples, setup from torchcfm.conditional_flow_matching import ( ConditionalFlowMatcher, @@ -102,8 +102,6 @@ def train(rank, total_num_gpus, argv): drop_last=True, ) - datalooper = infiniteloop(dataloader) - # Calculate number of epochs steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size) num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch)