[FEA] Heterogeneous Distributed Sampling (#4795)
Adds support for heterogeneous distributed sampling to the cuGraph distributed sampler. This is a prerequisite for exposing the functionality in cuGraph-PyG, and it has been initially tested with cuGraph-PyG.

Updates the distributed sampler to use the new sampling API.
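
In heterogeneous results, per-edge fields such as edge_id, edge_type, or weight may be absent for a given edge type, which is why the writer diff below replaces direct None checks with key-presence guards. A minimal sketch of that pattern; the minibatch_dict contents here are illustrative, not actual sampler output:

# Illustrative only: heterogeneous output may omit optional keys entirely.
minibatch_dict = {
    "majors": [0, 1, 2],
    "minors": [1, 2, 0],
    "edge_id": None,  # key present but explicitly empty for this edge type
    # "weight" is not present at all
}

# Guard against both a missing key and an explicit None value.
has_edge_ids = "edge_id" in minibatch_dict and minibatch_dict["edge_id"] is not None
has_weights = "weight" in minibatch_dict and minibatch_dict["weight"] is not None

assert not has_edge_ids and not has_weights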

Merge after #4775, #4827, #4820

Closes #4773 
Closes #4401

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Joseph Nke (https://github.com/jnke2016)
  - Ralph Liu (https://github.com/nv-rliu)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4795
alexbarghi-nv authored Jan 13, 2025
1 parent dd228f9 commit 8507cbf
Showing 10 changed files with 349 additions and 274 deletions.
79 changes: 49 additions & 30 deletions python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -79,9 +79,15 @@ def get_reader(
         return DistSampleReader(self._directory, format=self._format, rank=rank)

     def __write_minibatches_coo(self, minibatch_dict):
-        has_edge_ids = minibatch_dict["edge_id"] is not None
-        has_edge_types = minibatch_dict["edge_type"] is not None
-        has_weights = minibatch_dict["weight"] is not None
+        has_edge_ids = (
+            "edge_id" in minibatch_dict and minibatch_dict["edge_id"] is not None
+        )
+        has_edge_types = (
+            "edge_type" in minibatch_dict and minibatch_dict["edge_type"] is not None
+        )
+        has_weights = (
+            "weight" in minibatch_dict and minibatch_dict["weight"] is not None
+        )

         if minibatch_dict["renumber_map"] is None:
             raise ValueError(
@@ -92,22 +98,22 @@ def __write_minibatches_coo(self, minibatch_dict):
         if len(minibatch_dict["batch_id"]) == 0:
             return

-        fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len(
-            minibatch_dict["batch_id"]
-        )
+        fanout_length = len(minibatch_dict["fanout"])
+        total_num_batches = (
+            len(minibatch_dict["label_hop_offsets"]) - 1
+        ) / fanout_length

-        for p in range(
-            0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition))
-        ):
+        for p in range(0, int(ceil(total_num_batches / self.__batches_per_partition))):
             partition_start = p * (self.__batches_per_partition)
             partition_end = (p + 1) * (self.__batches_per_partition)

             label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][
                 partition_start * fanout_length : partition_end * fanout_length + 1
             ]

-            batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
-            start_batch_id = batch_id_array_p[0]
+            num_batches_p = len(label_hop_offsets_array_p) - 1
+
+            start_batch_id = minibatch_dict["batch_start"]

             input_offsets_p = minibatch_dict["input_offsets"][
                 partition_start : (partition_end + 1)
@@ -171,7 +177,7 @@ def __write_minibatches_coo(self, minibatch_dict):
                 }
             )

-            end_batch_id = start_batch_id + len(batch_id_array_p) - 1
+            end_batch_id = start_batch_id + num_batches_p - 1
             rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0

             full_output_path = os.path.join(
@@ -188,9 +194,15 @@ def __write_minibatches_coo(self, minibatch_dict):
         )

     def __write_minibatches_csr(self, minibatch_dict):
-        has_edge_ids = minibatch_dict["edge_id"] is not None
-        has_edge_types = minibatch_dict["edge_type"] is not None
-        has_weights = minibatch_dict["weight"] is not None
+        has_edge_ids = (
+            "edge_id" in minibatch_dict and minibatch_dict["edge_id"] is not None
+        )
+        has_edge_types = (
+            "edge_type" in minibatch_dict and minibatch_dict["edge_type"] is not None
+        )
+        has_weights = (
+            "weight" in minibatch_dict and minibatch_dict["weight"] is not None
+        )

         if minibatch_dict["renumber_map"] is None:
             raise ValueError(
@@ -201,22 +213,22 @@ def __write_minibatches_csr(self, minibatch_dict):
         if len(minibatch_dict["batch_id"]) == 0:
             return

-        fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len(
-            minibatch_dict["batch_id"]
-        )
+        fanout_length = len(minibatch_dict["fanout"])
+        total_num_batches = (
+            len(minibatch_dict["label_hop_offsets"]) - 1
+        ) / fanout_length

-        for p in range(
-            0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition))
-        ):
+        for p in range(0, int(ceil(total_num_batches / self.__batches_per_partition))):
             partition_start = p * (self.__batches_per_partition)
             partition_end = (p + 1) * (self.__batches_per_partition)

             label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][
                 partition_start * fanout_length : partition_end * fanout_length + 1
             ]

-            batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
-            start_batch_id = batch_id_array_p[0]
+            num_batches_p = len(label_hop_offsets_array_p) - 1
+
+            start_batch_id = minibatch_dict["batch_start"]

             input_offsets_p = minibatch_dict["input_offsets"][
                 partition_start : (partition_end + 1)
@@ -292,7 +304,7 @@ def __write_minibatches_csr(self, minibatch_dict):
                 }
             )

-            end_batch_id = start_batch_id + len(batch_id_array_p) - 1
+            end_batch_id = start_batch_id + num_batches_p - 1
             rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0

             full_output_path = os.path.join(
@@ -309,12 +321,19 @@ def __write_minibatches_csr(self, minibatch_dict):
         )

     def write_minibatches(self, minibatch_dict):
-        if (minibatch_dict["majors"] is not None) and (
-            minibatch_dict["minors"] is not None
-        ):
+        if "minors" not in minibatch_dict:
+            raise ValueError("invalid columns")
+
+        # PLC API specifies this behavior for empty input
+        # This needs to be handled here to avoid causing a hang
+        if len(minibatch_dict["minors"]) == 0:
+            return
+
+        if "majors" in minibatch_dict and minibatch_dict["majors"] is not None:
             self.__write_minibatches_coo(minibatch_dict)
-        elif (minibatch_dict["major_offsets"] is not None) and (
-            minibatch_dict["minors"] is not None
+        elif (
+            "major_offsets" in minibatch_dict
+            and minibatch_dict["major_offsets"] is not None
         ):
             self.__write_minibatches_csr(minibatch_dict)
         else:
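
To make the new partitioning arithmetic concrete, here is a small worked example of how fanout_length, total_num_batches, and the per-partition offset slices relate. All values are hypothetical; only the arithmetic mirrors the diff above:

from math import ceil

# Hypothetical sizes, chosen only to illustrate the arithmetic.
fanout = [10, 10]           # two hops -> fanout_length == 2
fanout_length = len(fanout)

# label_hop_offsets holds one boundary per (batch, hop) segment, plus one:
# 6 batches * 2 hops = 12 segments -> 13 offsets.
label_hop_offsets = list(range(13))
total_num_batches = (len(label_hop_offsets) - 1) / fanout_length  # 6.0

batches_per_partition = 4
num_partitions = int(ceil(total_num_batches / batches_per_partition))  # 2

for p in range(num_partitions):
    partition_start = p * batches_per_partition
    partition_end = (p + 1) * batches_per_partition
    # Slice out every hop boundary for the batches in this partition.
    offsets_p = label_hop_offsets[
        partition_start * fanout_length : partition_end * fanout_length + 1
    ]
    # Recover the batch count by dividing segments by hops per batch.
    batches_in_p = (len(offsets_p) - 1) // fanout_length
    print(p, batches_in_p)  # partition 0 -> 4 batches, partition 1 -> 2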

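The write_minibatches change above also adds an early return for empty sampler output, which the comment in the diff attributes to PLC (pylibcugraph) API behavior that could otherwise cause a hang. A standalone sketch of the resulting dispatch logic, using dict.get as shorthand for the key-presence guard in the diff:

def write_minibatches_sketch(minibatch_dict):
    # A "minors" column must always be present, even when empty.
    if "minors" not in minibatch_dict:
        raise ValueError("invalid columns")

    # Empty output is valid sampler behavior; skip the write entirely.
    if len(minibatch_dict["minors"]) == 0:
        return

    # d.get(k) is None covers both a missing key and an explicit None.
    if minibatch_dict.get("majors") is not None:
        print("COO layout: write majors/minors edge pairs")
    elif minibatch_dict.get("major_offsets") is not None:
        print("CSR layout: write major_offsets plus minors")
    else:
        raise ValueError("expected majors (COO) or major_offsets (CSR)")

# Example: empty heterogeneous output is silently skipped.
write_minibatches_sketch({"minors": []})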