Skip to content

Commit

Permalink
[GraphBolt] Fix hetero sampling bug with single fanout. (#7719)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Aug 18, 2024
1 parent fc29d0e commit 2ce0ea0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
3 changes: 2 additions & 1 deletion graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,10 @@ auto GetPickFn(
type_per_edge.value(), probs_or_mask, args, picked_data_ptr,
seed_offset, subgraph_indptr_ptr, etype_id_to_num_picked_offset);
} else {
picked_data_ptr += subgraph_indptr_ptr[seed_offset];
int64_t num_sampled = Pick(
offset, num_neighbors, fanouts[0], replace, options, probs_or_mask,
args, picked_data_ptr + subgraph_indptr_ptr[seed_offset]);
args, picked_data_ptr);
if (type_per_edge) {
std::sort(picked_data_ptr, picked_data_ptr + num_sampled);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1702,6 +1702,21 @@ def test_sample_neighbors_homo(
assert subgraph.original_row_node_ids is None


@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero_single_fanout(labor):
u, i = torch.randint(20, size=(1000,)), torch.randint(10, size=(1000,))
graph = dgl.heterograph({("u", "w", "i"): (u, i), ("i", "b", "u"): (i, u)})

graph = gb.from_dglgraph(graph).to(F.ctx())

sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors

for i in range(11):
nodes = {"u": torch.randint(10, (100,), device=F.ctx())}
sampler(nodes, fanouts=torch.tensor([-1]))
# Should reach here without crashing.


@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("labor", [False, True])
Expand Down

0 comments on commit 2ce0ea0

Please sign in to comment.