From d92c98d927d6ef0c0bcb7475194495f6f417e091 Mon Sep 17 00:00:00 2001 From: Wenxuan Cao <90617523+CfromBU@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:52:59 +0800 Subject: [PATCH] [distGB] change test_mp_dataloader (#7819) Co-authored-by: Ubuntu --- python/dgl/distributed/partition.py | 21 +++++++++++++++++++++ tests/distributed/test_partition.py | 2 ++ 2 files changed, 23 insertions(+) diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index de782a5f144d..f74da1cf9685 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -334,6 +334,27 @@ def load_partition(part_config, part_id, load_feats=True, use_graphbolt=False): "part-{}".format(part_id) in part_metadata ), "part-{} does not exist".format(part_id) part_files = part_metadata["part-{}".format(part_id)] + + exist_dgl_graph = exist_graphbolt_graph = False + if os.path.exists(os.path.join(config_path, f"part{part_id}", "graph.dgl")): + use_graphbolt = False + exist_dgl_graph = True + if os.path.exists( + os.path.join( + config_path, f"part{part_id}", "fused_csc_sampling_graph.pt" + ) + ): + use_graphbolt = True + exist_graphbolt_graph = True + + # Check if both DGL graph and GraphBolt graph exist or not exist. Make sure only one exists. + if not exist_dgl_graph and not exist_graphbolt_graph: + raise ValueError("The graph object doesn't exist.") + if exist_dgl_graph and exist_graphbolt_graph: + raise ValueError( + "Both DGL graph and GraphBolt graph exist. Please remove one." + ) + if use_graphbolt: part_graph_field = "part_graph_graphbolt" else: diff --git a/tests/distributed/test_partition.py b/tests/distributed/test_partition.py index ef075841e25b..a46e0b778367 100644 --- a/tests/distributed/test_partition.py +++ b/tests/distributed/test_partition.py @@ -990,6 +990,7 @@ def test_dgl_partition_to_graphbolt_homo( orig_g = dgl.load_graphs( os.path.join(test_dir, f"part{part_id}/graph.dgl") )[0][0] + os.remove(os.path.join(test_dir, f"part{part_id}/graph.dgl")) new_g = load_partition( part_config, part_id, load_feats=False, use_graphbolt=True )[0] @@ -1067,6 +1068,7 @@ def test_dgl_partition_to_graphbolt_hetero( orig_g = dgl.load_graphs( os.path.join(test_dir, f"part{part_id}/graph.dgl") )[0][0] + os.remove(os.path.join(test_dir, f"part{part_id}/graph.dgl")) new_g = load_partition( part_config, part_id, load_feats=False, use_graphbolt=True )[0]