Skip to content

Commit

Permalink
[GraphBolt][CUDA][Temporal] Tests and example enablement. (#7678)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Aug 9, 2024
1 parent 6d55515 commit 90c26be
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 59 deletions.
5 changes: 4 additions & 1 deletion examples/graphbolt/temporal_link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
shuffle=is_train,
)

if args.storage_device != "cpu":
datapipe = datapipe.copy_to(device=args.device)

############################################################################
# [Input]:
# 'datapipe' is either 'ItemSampler' or 'UniformNegativeSampler' depending
Expand Down Expand Up @@ -250,7 +253,7 @@ def parse_args():
parser.add_argument(
"--mode",
default="cpu-cuda",
choices=["cpu-cpu", "cpu-cuda"],
choices=["cpu-cpu", "cpu-cuda", "cuda-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -830,10 +830,6 @@ def test_in_subgraph_hetero():
)


@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
Expand All @@ -848,6 +844,8 @@ def test_temporal_sample_neighbors_homo(
use_node_timestamp,
use_edge_timestamp,
):
if replace and F._default_context_str == "gpu":
pytest.skip("Sampling with replacement not yet implemented on the GPU.")
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
Expand All @@ -867,7 +865,7 @@ def test_temporal_sample_neighbors_homo(
assert len(indptr) == total_num_nodes + 1

# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(indptr, indices)
graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())

# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([2])
Expand All @@ -878,15 +876,17 @@ def test_temporal_sample_neighbors_homo(
)

seed_list = [1, 3, 4]
seed_timestamp = torch.randint(0, 100, (len(seed_list),), dtype=torch.int64)
seed_timestamp = torch.randint(
0, 100, (len(seed_list),), dtype=torch.int64, device=F.ctx()
)
if use_node_timestamp:
node_timestamp = torch.randint(
0, 100, (total_num_nodes,), dtype=torch.int64
0, 100, (total_num_nodes,), dtype=torch.int64, device=F.ctx()
)
graph.node_attributes = {"timestamp": node_timestamp}
if use_edge_timestamp:
edge_timestamp = torch.randint(
0, 100, (total_num_edges,), dtype=torch.int64
0, 100, (total_num_edges,), dtype=torch.int64, device=F.ctx()
)
graph.edge_attributes = {"timestamp": edge_timestamp}

Expand Down Expand Up @@ -936,7 +936,7 @@ def _get_available_neighbors():
available_neighbors.append(neighbors)
return available_neighbors

nodes = torch.tensor(seed_list, dtype=indices_dtype)
nodes = torch.tensor(seed_list, dtype=indices_dtype, device=F.ctx())
subgraph = sampler(
nodes,
seed_timestamp,
Expand All @@ -947,6 +947,7 @@ def _get_available_neighbors():
)
sampled_count = torch.diff(subgraph.sampled_csc.indptr).tolist()
available_neighbors = _get_available_neighbors()
assert len(available_neighbors) == len(sampled_count)
for i, count in enumerate(sampled_count):
if not replace:
expect_count = min(fanouts[0], len(available_neighbors[i]))
Expand All @@ -958,10 +959,6 @@ def _get_available_neighbors():
assert set(neighbors.tolist()).issubset(set(available_neighbors[i]))


@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
Expand All @@ -976,6 +973,8 @@ def test_temporal_sample_neighbors_hetero(
use_node_timestamp,
use_edge_timestamp,
):
if replace and F._default_context_str == "gpu":
pytest.skip("Sampling with replacement not yet implemented on the GPU.")
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
Expand Down Expand Up @@ -1006,7 +1005,7 @@ def test_temporal_sample_neighbors_hetero(
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
).to(F.ctx())

# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([-1, -1])
Expand All @@ -1017,26 +1016,26 @@ def test_temporal_sample_neighbors_hetero(
)

seeds = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
"n2": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
}
per_etype_destination_nodes = {
"n1:e1:n2": torch.tensor([1], dtype=indices_dtype),
"n2:e2:n1": torch.tensor([0], dtype=indices_dtype),
}

seed_timestamp = {
"n1": torch.randint(0, 100, (1,), dtype=torch.int64),
"n2": torch.randint(0, 100, (1,), dtype=torch.int64),
"n1": torch.randint(0, 100, (1,), dtype=torch.int64, device=F.ctx()),
"n2": torch.randint(0, 100, (1,), dtype=torch.int64, device=F.ctx()),
}
if use_node_timestamp:
node_timestamp = torch.randint(
0, 100, (total_num_nodes,), dtype=torch.int64
0, 100, (total_num_nodes,), dtype=torch.int64, device=F.ctx()
)
graph.node_attributes = {"timestamp": node_timestamp}
if use_edge_timestamp:
edge_timestamp = torch.randint(
0, 100, (total_num_edges,), dtype=torch.int64
0, 100, (total_num_edges,), dtype=torch.int64, device=F.ctx()
)
graph.edge_attributes = {"timestamp": edge_timestamp}

Expand Down
Loading

0 comments on commit 90c26be

Please sign in to comment.