diff --git a/test/test_squeeze_bijection.py b/test/test_squeeze_bijection.py index 86a3e39..6952f34 100644 --- a/test/test_squeeze_bijection.py +++ b/test/test_squeeze_bijection.py @@ -19,31 +19,3 @@ def test_reconstruction(batch_shape, channels, height, width): assert torch.allclose(x_reconstructed, x) assert torch.allclose(log_det_forward, torch.zeros_like(log_det_forward)) assert torch.allclose(log_det_forward, log_det_inverse) - - -def test_efficient_forward(): - x = torch.tensor([ - [1, 2, 5, 6], - [3, 4, 7, 8], - [9, 10, 13, 14], - [11, 12, 15, 16] - ])[None, None] - layer = Squeeze(event_shape=x.shape[-3:]) - z, log_det_forward = layer.forward(x) - z2, log_det_forward2 = layer.forward2(x) - assert torch.allclose(z, z2) - - -def test_efficient_inverse(): - x = torch.tensor([ - [1, 2, 5, 6], - [3, 4, 7, 8], - [9, 10, 13, 14], - [11, 12, 15, 16] - ])[None, None] - layer = Squeeze(event_shape=x.shape[-3:]) - z, log_det_forward = layer.forward(x) - - xr, _ = layer.inverse2(z) - - assert torch.allclose(x, xr)