diff --git a/tests/layers/block/test_block_graphtransformer.py b/tests/layers/block/test_block_graphtransformer.py
index 00af258..9826314 100644
--- a/tests/layers/block/test_block_graphtransformer.py
+++ b/tests/layers/block/test_block_graphtransformer.py
@@ -303,9 +303,10 @@ def test_GraphTransformerMapperBlock_forward_backward(init, mapper_block):
     edge_index = torch.randint(1, 10, (2, 10))
     shapes = (10, 10, 10)
     batch_size = 1
+    size = (10, 10)

     # Forward pass
-    output, _ = block(x, edge_attr, edge_index, shapes, batch_size)
+    output, _ = block(x, edge_attr, edge_index, shapes, batch_size, size=size)

     # Check output shape
     assert output[0].shape == (10, out_channels)
@@ -327,3 +328,39 @@ def test_GraphTransformerMapperBlock_forward_backward(init, mapper_block):
     assert (
         param.grad.shape == param.shape
     ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}"
+
+
+def test_GraphTransformerMapperBlock_chunking(init, mapper_block):
+    (
+        in_channels,
+        _hidden_dim,
+        _out_channels,
+        edge_dim,
+        _bias,
+        _activation,
+        _num_heads,
+        _num_chunks,
+    ) = init
+    # Initialize GraphTransformerMapperBlock
+    block = mapper_block
+
+    # Generate random input tensors
+    x = (torch.randn((10, in_channels)), torch.randn((10, in_channels)))
+    edge_attr = torch.randn((10, edge_dim))
+    edge_index = torch.randint(1, 10, (2, 10))
+    shapes = (10, 10, 10)
+    batch_size = 1
+    size = (10, 10)
+    num_chunks = torch.randint(2, 10, (1,)).item()
+
+    # Output with chunking enabled:
+    block.num_chunks = num_chunks
+    out_chunked, _ = block(x, edge_attr, edge_index, shapes, batch_size, size=size)
+    # Output without chunking:
+    block.num_chunks = 1
+    out, _ = block(x, edge_attr, edge_index, shapes, batch_size, size=size)
+
+    assert out[0].shape == out_chunked[0].shape, f"out[0].shape ({out[0].shape}) != out_chunked[0].shape ({out_chunked[0].shape})"
+    assert out[1].shape == out_chunked[1].shape, f"out[1].shape ({out[1].shape}) != out_chunked[1].shape ({out_chunked[1].shape})"
+    assert torch.allclose(out[0], out_chunked[0], atol=1e-4), "out[0] != out_chunked[0]"
+    assert torch.allclose(out[1], out_chunked[1], atol=1e-4), "out[1] != out_chunked[1]"