From 77767edeebbca53006780a2c6e51196710d1bf5a Mon Sep 17 00:00:00 2001 From: Zecheng Zhang Date: Thu, 5 Aug 2021 00:12:25 -0700 Subject: [PATCH] Add hetero_gnn forward_op test --- tests/test_hetero_gnn.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 tests/test_hetero_gnn.py diff --git a/tests/test_hetero_gnn.py b/tests/test_hetero_gnn.py new file mode 100644 index 0000000..5fdafaa --- /dev/null +++ b/tests/test_hetero_gnn.py @@ -0,0 +1,28 @@ +import math +import torch +import unittest + +from torch import nn +from deepsnap.hetero_gnn import forward_op + +class TestHeteroGNN(unittest.TestCase): + + def test_hetero_gnn_forward(self): + xs = {} + layers = nn.ModuleDict() + emb_dim = 5 + feat_dim = 10 + num_samples = 8 + keys = ['a', 'b', 'c'] + + for key in keys: + layers[key] = nn.Linear(feat_dim, emb_dim) + xs[key] = torch.ones(num_samples, feat_dim) + + ys = forward_op(xs, layers) + for key in keys: + self.assertEqual(ys[key].shape[0], num_samples) + self.assertEqual(ys[key].shape[1], emb_dim) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file