diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c62ab84..170c694 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -10,11 +10,15 @@ simple_networkx_graph, simple_networkx_small_graph, simple_networkx_graph_alphabet, + simple_networkx_dense_graph, + simple_networkx_dense_multigraph, simple_networkx_multigraph, generate_dense_hete_dataset, generate_simple_small_hete_graph, + generate_simple_dense_hete_graph, + generate_simple_dense_hete_multigraph, generate_dense_hete_multigraph, - gen_graph, + gen_graph ) @@ -2463,6 +2467,80 @@ def test_secure_split(self): self.assertEqual(num_val, len(split_res[1])) self.assertEqual(num_test, len(split_res[2])) + def test_negative_sampling_edge_case_heterogeneous(self): + # complete graph + G = generate_simple_dense_hete_graph() + graph = HeteroGraph(G) + graphs = [graph] + dataset = GraphDataset(graphs, task="link_pred") + self.assertRaises(ValueError, dataset[0]._create_neg_sampling, 1) + + # complete graph except 1 missing edge + G = generate_simple_dense_hete_graph(num_edges_removed=1) + graph = HeteroGraph(G) + graphs = [graph] + dataset = GraphDataset(graphs, task="link_pred") + dataset[0]._create_neg_sampling(1) + for message_type in dataset[0].message_types: + num_edges = dataset[0].num_edges(message_type) + self.assertEqual( + dataset[0].edge_label[message_type].shape[0], + 2 * num_edges + ) + + # complete multigraph + G = generate_simple_dense_hete_multigraph() + graph = HeteroGraph(G) + graphs = [graph] + dataset = GraphDataset(graphs, task="link_pred") + self.assertRaises(ValueError, dataset[0]._create_neg_sampling, 1) + + # complete multigraph except 1 missing edge + G = generate_simple_dense_hete_multigraph(num_edges_removed=1) + graph = HeteroGraph(G) + graphs = [graph] + dataset = GraphDataset(graphs, task="link_pred") + dataset[0]._create_neg_sampling(1) + for message_type in dataset[0].message_types: + num_edges = dataset[0].num_edges(message_type) + self.assertEqual( + dataset[0].edge_label[message_type].shape[0], + 2 * num_edges + ) + + def test_negative_sampling_edge_case(self): + # complete graph + G = simple_networkx_dense_graph() + graph = Graph(G) + graphs = [graph] + dataset = GraphDataset(graphs, task="link_pred") + self.assertRaises(ValueError, dataset[0]._create_neg_sampling, 1) + + # complete graph except 1 missing edge + G = simple_networkx_dense_graph(num_edges_removed=1) + graph = Graph(G) + graphs = [graph] + dataset = GraphDataset(graphs, task="link_pred") + num_edges = dataset.num_edges[0] + dataset[0]._create_neg_sampling(1) + self.assertEqual(dataset[0].edge_label.shape[0], 2 * num_edges) + + # complete multigraph + G = simple_networkx_dense_multigraph() + graph = Graph(G) + graphs = [graph] + dataset = GraphDataset(graphs, task="link_pred") + self.assertRaises(ValueError, dataset[0]._create_neg_sampling, 1) + + # complete multigraph except 1 missing edge + G = simple_networkx_dense_multigraph(num_edges_removed=1) + graph = Graph(G) + graphs = [graph] + dataset = GraphDataset(graphs, task="link_pred") + num_edges = dataset.num_edges[0] + dataset[0]._create_neg_sampling(1) + self.assertEqual(dataset[0].edge_label.shape[0], 2 * num_edges) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_dataset_tensor.py b/tests/test_dataset_tensor.py index f57ebb8..02d37a2 100644 --- a/tests/test_dataset_tensor.py +++ b/tests/test_dataset_tensor.py @@ -15,7 +15,7 @@ simple_networkx_multigraph, generate_dense_hete_dataset, generate_simple_small_hete_graph, - gen_graph, + gen_graph )