diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py
index aa7b844..a3a8d21 100644
--- a/normalizing_flows/bijections/finite/multiscale/base.py
+++ b/normalizing_flows/bijections/finite/multiscale/base.py
@@ -13,6 +13,89 @@ from normalizing_flows.utils import get_batch_shape
 
 
 
+class FactoredBijection(Bijection):
+    """
+    Factored bijection class.
+
+    Partitions the input tensor x into parts x_A and x_B, then applies a bijection to x_A independently of x_B while
+    keeping x_B identical.
+    """
+
+    def __init__(self,
+                 event_shape: Union[torch.Size, Tuple[int, ...]],
+                 transformed_event_shape: Union[torch.Size, Tuple[int, ...]],
+                 small_bijection: Bijection,
+                 transformed_event_mask: torch.Tensor,
+                 **kwargs):
+        """
+        Validate the shapes and store the mask, the transformed event shape, and the small bijection.
+
+        :param event_shape: shape of input event x.
+        :param transformed_event_shape: shape of transformed event x_A.
+        :param small_bijection: bijection applied to transformed event x_A.
+        :param transformed_event_mask: boolean mask that selects which elements of event x correspond to the transformed
+            event x_A. Must have shape event_shape and select exactly prod(transformed_event_shape) elements.
+        :param kwargs: forwarded to the superclass constructor.
+        """
+        super().__init__(event_shape, **kwargs)
+
+        # Check that shapes are correct
+        event_size = int(torch.prod(torch.as_tensor(event_shape)))
+        transformed_event_size = int(torch.prod(torch.as_tensor(transformed_event_shape)))
+        assert event_size >= transformed_event_size
+
+        assert transformed_event_mask.shape == event_shape
+        assert small_bijection.event_shape == transformed_event_shape
+
+        # The mask must select exactly as many elements as x_A holds; otherwise reshaping the selection fails later.
+        assert int(transformed_event_mask.sum()) == transformed_event_size, \
+            "transformed_event_mask must select exactly prod(transformed_event_shape) elements"
+
+        # NOTE(review): the mask is stored as a plain attribute; if Bijection is a torch.nn.Module, consider
+        # registering it as a buffer so it follows the module across devices -- confirm against the base class.
+        self.transformed_event_mask = transformed_event_mask
+        self.transformed_event_shape = transformed_event_shape
+        self.small_bijection = small_bijection
+
+    def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply the small bijection to the masked part x_A of x and copy x_B through unchanged.
+
+        :param x: input tensor with shape (*batch_shape, *event_shape).
+        :param context: context tensor passed to the small bijection.
+        :return: transformed tensor with the same shape as x and the log determinant returned by the small bijection.
+        """
+        batch_shape = get_batch_shape(x, self.event_shape)
+        # Boolean indexing flattens the event dimensions; restore the shape the small bijection expects.
+        transformed, log_det = self.small_bijection.forward(
+            x[..., self.transformed_event_mask].reshape(*batch_shape, *self.transformed_event_shape),
+            context
+        )
+        out = x.clone()
+        # reshape (not view): the small bijection's output is not guaranteed to be contiguous.
+        out[..., self.transformed_event_mask] = transformed.reshape(*batch_shape, -1)
+        return out, log_det
+
+    def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply the inverse of the small bijection to the masked part of z and copy the rest through unchanged.
+
+        :param z: input tensor with shape (*batch_shape, *event_shape).
+        :param context: context tensor passed to the small bijection.
+        :return: transformed tensor with the same shape as z and the log determinant returned by the small bijection.
+        """
+        batch_shape = get_batch_shape(z, self.event_shape)
+        # Boolean indexing flattens the event dimensions; restore the shape the small bijection expects.
+        transformed, log_det = self.small_bijection.inverse(
+            z[..., self.transformed_event_mask].reshape(*batch_shape, *self.transformed_event_shape),
+            context
+        )
+        out = z.clone()
+        # reshape (not view): the small bijection's output is not guaranteed to be contiguous.
+        out[..., self.transformed_event_mask] = transformed.reshape(*batch_shape, -1)
+        return out, log_det
+
+
 class ConvNetConditioner(ConditionerTransform):
     def __init__(self,
                  input_event_shape: torch.Size,
diff --git a/test/test_factored_bijection.py b/test/test_factored_bijection.py
new file mode 100644
index 0000000..6df7726
--- /dev/null
+++ b/test/test_factored_bijection.py
@@ -0,0 +1,43 @@
+import torch
+from normalizing_flows.bijections.finite.multiscale.base import FactoredBijection
+from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine
+
+
+def test_basic():
+    """Check that only the masked block is transformed and that inverse undoes forward."""
+    torch.manual_seed(0)
+
+    bijection = FactoredBijection(
+        event_shape=(6, 6),
+        transformed_event_shape=(3, 3),
+        transformed_event_mask=torch.tensor([
+            [True, True, True, False, False, False],
+            [True, True, True, False, False, False],
+            [True, True, True, False, False, False],
+            [False, False, False, False, False, False],
+            [False, False, False, False, False, False],
+            [False, False, False, False, False, False],
+        ]),
+        small_bijection=ElementwiseAffine(event_shape=(3, 3))
+    )
+
+    x = torch.randn(100, *bijection.event_shape)
+    z, log_det_forward = bijection.forward(x)
+
+    assert torch.allclose(
+        x[..., ~bijection.transformed_event_mask],
+        z[..., ~bijection.transformed_event_mask],
+        atol=1e-5
+    )
+
+    assert torch.all(
+        ~torch.isclose(
+            x[..., bijection.transformed_event_mask],
+            z[..., bijection.transformed_event_mask],
+            atol=1e-5
+        )
+    )
+
+    xr, log_det_inverse = bijection.inverse(z)
+    assert torch.allclose(x, xr, atol=1e-5)
+    assert torch.allclose(log_det_forward, -log_det_inverse, atol=1e-5)