Skip to content

Commit

Permalink
[distGB] change test_mp_dataloader (#7819)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
CfromBU and Ubuntu authored Oct 16, 2024
1 parent 4a6bfa4 commit d92c98d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
21 changes: 21 additions & 0 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,27 @@ def load_partition(part_config, part_id, load_feats=True, use_graphbolt=False):
"part-{}".format(part_id) in part_metadata
), "part-{} does not exist".format(part_id)
part_files = part_metadata["part-{}".format(part_id)]

exist_dgl_graph = exist_graphbolt_graph = False
if os.path.exists(os.path.join(config_path, f"part{part_id}", "graph.dgl")):
use_graphbolt = False
exist_dgl_graph = True
if os.path.exists(
os.path.join(
config_path, f"part{part_id}", "fused_csc_sampling_graph.pt"
)
):
use_graphbolt = True
exist_graphbolt_graph = True

# Check if both DGL graph and GraphBolt graph exist or not exist. Make sure only one exists.
if not exist_dgl_graph and not exist_graphbolt_graph:
raise ValueError("The graph object doesn't exist.")
if exist_dgl_graph and exist_graphbolt_graph:
raise ValueError(
"Both DGL graph and GraphBolt graph exist. Please remove one."
)

if use_graphbolt:
part_graph_field = "part_graph_graphbolt"
else:
Expand Down
2 changes: 2 additions & 0 deletions tests/distributed/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ def test_dgl_partition_to_graphbolt_homo(
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
os.remove(os.path.join(test_dir, f"part{part_id}/graph.dgl"))
new_g = load_partition(
part_config, part_id, load_feats=False, use_graphbolt=True
)[0]
Expand Down Expand Up @@ -1067,6 +1068,7 @@ def test_dgl_partition_to_graphbolt_hetero(
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
os.remove(os.path.join(test_dir, f"part{part_id}/graph.dgl"))
new_g = load_partition(
part_config, part_id, load_feats=False, use_graphbolt=True
)[0]
Expand Down

0 comments on commit d92c98d

Please sign in to comment.