Skip to content

Commit

Permalink
Add hetero_gnn forward_op test
Browse files Browse the repository at this point in the history
  • Loading branch information
zechengz committed Aug 5, 2021
1 parent 8b4bcff commit 77767ed
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/test_hetero_gnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import math
import torch
import unittest

from torch import nn
from deepsnap.hetero_gnn import forward_op

class TestHeteroGNN(unittest.TestCase):

def test_hetero_gnn_forward(self):
xs = {}
layers = nn.ModuleDict()
emb_dim = 5
feat_dim = 10
num_samples = 8
keys = ['a', 'b', 'c']

for key in keys:
layers[key] = nn.Linear(feat_dim, emb_dim)
xs[key] = torch.ones(num_samples, feat_dim)

ys = forward_op(xs, layers)
for key in keys:
self.assertEqual(ys[key].shape[0], num_samples)
self.assertEqual(ys[key].shape[1], emb_dim)

if __name__ == "__main__":
unittest.main()

0 comments on commit 77767ed

Please sign in to comment.