Skip to content

Commit

Permalink
Added tests for the subset_in_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
DomInvivo committed Dec 15, 2023
1 parent f95872d commit 8ad1822
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 5 deletions.
8 changes: 3 additions & 5 deletions graphium/nn/architectures/global_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,9 +684,7 @@ def _parse_subset_in_dim(
if subset_in_dim == in_dim:
subset_idx = None
else:
subset_idx = torch.stack(
[torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)]
).unsqueeze(-2)
subset_idx = torch.stack([torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)])

return subset_in_dim, subset_idx

Expand Down Expand Up @@ -721,8 +719,8 @@ def forward(self, h: torch.Tensor) -> torch.Tensor:
# Subset the input features for each MLP in the ensemble
if self.subset_idx is not None:
if len(h.shape) != 2:
assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}"
h = h[..., self.subset_idx]
assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}."
h = h[..., self.subset_idx].transpose(-2, -3)

# Run the standard forward pass
h = super().forward(h)
Expand Down
63 changes: 63 additions & 0 deletions tests/test_ensemble_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,37 @@ def check_ensemble_feedforwardnn_mean(
individual_output = torch.stack(individual_outputs, dim=-3).mean(dim=-3).detach().numpy()
np.testing.assert_allclose(ensemble_output, individual_output, atol=1e-5, err_msg=msg)

def check_ensemble_feedforwardnn_simple(
self,
in_dim: int,
out_dim: int,
num_ensemble: int,
batch_size: int,
more_batch_dim: int,
last_layer_is_readout=False,
**kwargs,
):
msg = f"Testing EnsembleFeedForwardNN with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"

# Create EnsembleFeedForwardNN instance
hidden_dims = [17, 17, 17]
ensemble_mlp = EnsembleFeedForwardNN(
in_dim,
out_dim,
hidden_dims,
num_ensemble,
reduction=None,
last_layer_is_readout=last_layer_is_readout,
**kwargs,
)

# Test with a sample input
input_tensor = torch.randn(batch_size, in_dim)
ensemble_output = ensemble_mlp(input_tensor)

# Check for the output shape
self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg)

def test_ensemble_feedforwardnn(self):
# more_batch_dim=0
self.check_ensemble_feedforwardnn(
Expand Down Expand Up @@ -516,6 +547,38 @@ def test_ensemble_feedforwardnn(self):
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True
)

# Test `subset_in_dim`
self.check_ensemble_feedforwardnn_simple(
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, subset_in_dim=0.5
)
self.check_ensemble_feedforwardnn_simple(
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, subset_in_dim=0.5
)
self.check_ensemble_feedforwardnn_simple(
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, subset_in_dim=0.5
)
self.check_ensemble_feedforwardnn_simple(
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, subset_in_dim=7
)
self.check_ensemble_feedforwardnn_simple(
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, subset_in_dim=7
)
self.check_ensemble_feedforwardnn_simple(
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, subset_in_dim=7
)
with self.assertRaises(AssertionError):
self.check_ensemble_feedforwardnn_simple(
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, subset_in_dim=1.5
)
with self.assertRaises(AssertionError):
self.check_ensemble_feedforwardnn_simple(
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, subset_in_dim=39
)
with self.assertRaises(AssertionError):
self.check_ensemble_feedforwardnn_simple(
in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, subset_in_dim=39
)


if __name__ == "__main__":
ut.main()

0 comments on commit 8ad1822

Please sign in to comment.