diff --git a/examples/multigpu/graphbolt/node_classification.py b/examples/multigpu/graphbolt/node_classification.py
index 5ef93311fe55..3ddf25792ede 100644
--- a/examples/multigpu/graphbolt/node_classification.py
+++ b/examples/multigpu/graphbolt/node_classification.py
@@ -134,11 +134,11 @@ def create_dataloader(
     # [Output]:
     # A CopyTo object copying data in the datapipe to a specified device.\
     ############################################################################
-    if not args.cpu_sampling:
+    if args.storage_device != "cpu":
         datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
     datapipe = datapipe.sample_neighbor(graph, args.fanout)
     datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
-    if args.cpu_sampling:
+    if args.storage_device == "cpu":
         datapipe = datapipe.copy_to(device)
 
     dataloader = gb.DataLoader(datapipe, args.num_workers)
@@ -276,7 +276,7 @@ def run(rank, world_size, args, devices, dataset):
     )
 
     # Pin the graph and features to enable GPU access.
-    if not args.cpu_sampling:
+    if args.storage_device == "pinned":
         dataset.graph.pin_memory_()
         dataset.feature.pin_memory_()
 
@@ -388,15 +388,17 @@ def parse_args():
         type=str,
         default="10,10,10",
         help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
-        " identical with the number of layers in your model. Default: 15,10,5",
+        " identical with the number of layers in your model. Default: 10,10,10",
     )
     parser.add_argument(
         "--num-workers", type=int, default=0, help="The number of processes."
     )
     parser.add_argument(
-        "--cpu-sampling",
-        action="store_true",
-        help="Disables GPU sampling and utilizes the CPU for dataloading.",
+        "--mode",
+        default="pinned-cuda",
+        choices=["cpu-cuda", "pinned-cuda"],
+        help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
+        " 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
     )
     return parser.parse_args()
 
@@ -406,6 +408,7 @@ def parse_args():
     if not torch.cuda.is_available():
         print(f"Multi-gpu training needs to be in gpu mode.")
         exit(0)
+    args.storage_device, _ = args.mode.split("-")
     devices = list(map(int, args.gpu.split(",")))
     world_size = len(devices)
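
Usage note (not part of the patch): a minimal sketch of how the new --mode flag is meant to drive dataset placement. Only the argument definition, the mode.split("-") line, and the pin_memory_() calls mirror the diff; the parse_mode and place_dataset wrappers are hypothetical helpers added here for illustration.

import argparse


def parse_mode():
    """Parse --mode and derive the storage device, mirroring the patch."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="pinned-cuda",
        choices=["cpu-cuda", "pinned-cuda"],
        help="'<storage>-<train device>': where the dataset lives and where"
        " training runs.",
    )
    args = parser.parse_args()
    # "pinned-cuda" -> storage_device == "pinned"; the train-device half is
    # discarded here, just as in the patch.
    args.storage_device, _ = args.mode.split("-")
    return args


def place_dataset(args, dataset):
    # Hypothetical helper: with pinned storage, the graph and features stay in
    # page-locked host memory so the GPU can read them directly during
    # sampling and feature fetching; with plain CPU storage, sampling runs on
    # the CPU and each mini-batch is copied to the GPU only after features are
    # fetched.
    if args.storage_device == "pinned":
        dataset.graph.pin_memory_()
        dataset.feature.pin_memory_()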