diff --git a/python/dgl/graphbolt/negative_sampler.py b/python/dgl/graphbolt/negative_sampler.py index 3cef15da83e6..2d43d9ac7f9b 100644 --- a/python/dgl/graphbolt/negative_sampler.py +++ b/python/dgl/graphbolt/negative_sampler.py @@ -82,7 +82,9 @@ def _sample(self, minibatch): minibatch.seeds[etype], minibatch.labels[etype], minibatch.indexes[etype], - ) = self._sample_with_etype(pos_pairs, use_seeds=True) + ) = self._sample_with_etype( + pos_pairs, etype, use_seeds=True + ) else: ( minibatch.seeds, diff --git a/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py b/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py index f13278a0a9bf..44aab2d8b8bb 100644 --- a/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py @@ -293,7 +293,23 @@ def test_NegativeSampler_Hetero_Data(): ), } ) - - item_sampler = gb.ItemSampler(itemset, batch_size=2) - negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1) + batch_size = 2 + negative_ratio = 1 + item_sampler = gb.ItemSampler(itemset, batch_size=batch_size) + negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio) assert len(list(negative_dp)) == 5 + # Perform negative sampling. + expected_neg_src = [ + {"n1:e1:n2": torch.tensor([0, 0])}, + {"n1:e1:n2": torch.tensor([1, 1])}, + {"n2:e2:n1": torch.tensor([0, 0])}, + {"n2:e2:n1": torch.tensor([1, 1])}, + {"n2:e2:n1": torch.tensor([2, 2])}, + ] + for i, data in enumerate(negative_dp): + # Check negative seeds value. + for etype, seeds_data in data.seeds.items(): + neg_src = seeds_data[batch_size:, 0] + neg_dst = seeds_data[batch_size:, 1] + assert torch.equal(expected_neg_src[i][etype], neg_src) + assert (neg_dst < 3).all(), neg_dst