Skip to content

Commit

Permalink
test: add chunking tests for GraphTransformerMapperBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
japols committed Sep 26, 2024
1 parent 0da9a09 commit fa2dda3
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion tests/layers/block/test_block_graphtransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,10 @@ def test_GraphTransformerMapperBlock_forward_backward(init, mapper_block):
edge_index = torch.randint(1, 10, (2, 10))
shapes = (10, 10, 10)
batch_size = 1
size = (10, 10)

# Forward pass
output, _ = block(x, edge_attr, edge_index, shapes, batch_size)
output, _ = block(x, edge_attr, edge_index, shapes, batch_size, size=size)

# Check output shape
assert output[0].shape == (10, out_channels)
Expand All @@ -327,3 +328,39 @@ def test_GraphTransformerMapperBlock_forward_backward(init, mapper_block):
assert (
param.grad.shape == param.shape
), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}"


def test_GraphTransformerMapperBlock_chunking(init, mapper_block):
(
in_channels,
_hidden_dim,
_out_channels,
edge_dim,
_bias,
_activation,
_num_heads,
_num_chunks,
) = init
# Initialize GraphTransformerMapperBlock
block = mapper_block

# Generate random input tensor
x = (torch.randn((10, in_channels)), torch.randn((10, in_channels)))
edge_attr = torch.randn((10, edge_dim))
edge_index = torch.randint(1, 10, (2, 10))
shapes = (10, 10, 10)
batch_size = 1
size = (10, 10)
num_chunks = torch.randint(2, 10, (1,)).item()

# result with chunks:
block.num_chunks = num_chunks
out_chunked, _ = block(x, edge_attr, edge_index, shapes, batch_size, size=size)
# result without chunks:
block.num_chunks = 1
out, _ = block(x, edge_attr, edge_index, shapes, batch_size, size=size)

assert out[0].shape == out_chunked[0].shape, f"out.shape ({out.shape}) != out_chunked.shape ({out_chunked.shape})"
assert out[1].shape == out_chunked[1].shape, f"out.shape ({out.shape}) != out_chunked.shape ({out_chunked.shape})"
assert torch.allclose(out[0], out_chunked[0], atol=1e-4), "out != out_chunked"
assert torch.allclose(out[1], out_chunked[1], atol=1e-4), "out != out_chunked"

0 comments on commit fa2dda3

Please sign in to comment.