Skip to content

Commit

Permalink
Store/Load CIFAR from local/offline (#6390)
Browse files Browse the repository at this point in the history
CIFAR10_DATASET_PATH -> Path where the dataset is stored
STORE_CIFAR10        -> Store the dataset 1/0
CIFAR10_OFFLINE      -> To use offline dataset 1/0
MISC:
Added getDeviceId to get device if by name in case of accelerator

---------

Co-authored-by: Shaik Raza Sikander <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
5 people authored Aug 28, 2024
1 parent b5cf30a commit 9bc4cd0
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions tests/unit/alexnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,14 @@ def cifar_trainset(fp16=False):
dist.barrier()
if local_rank != 0:
dist.barrier()

data_root = os.getenv("TEST_DATA_DIR", "/tmp/")
trainset = torchvision.datasets.CIFAR10(root=os.path.join(data_root, "cifar10-data"),
train=True,
download=True,
transform=transform)
if os.getenv("CIFAR10_DATASET_PATH"):
data_root = os.getenv("CIFAR10_DATASET_PATH")
download = False
else:
data_root = os.path.join(os.getenv("TEST_DATA_DIR", "/tmp"), "cifar10-data")
download = True
trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform)
if local_rank == 0:
dist.barrier()
return trainset
Expand Down

0 comments on commit 9bc4cd0

Please sign in to comment.