diff --git a/torchflows/bijections/base.py b/torchflows/bijections/base.py index 02e7b57..57c0ca6 100644 --- a/torchflows/bijections/base.py +++ b/torchflows/bijections/base.py @@ -11,6 +11,7 @@ class Bijection(nn.Module): """Bijection class. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], context_shape: Union[torch.Size, Tuple[int, ...]] = None, @@ -50,9 +51,17 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. """ raise NotImplementedError - @staticmethod - def batch_apply(fn, batch_size, *args): - dataset = TensorDataset(*args) + def batch_apply(self, fn, batch_size, x, context=None): + batch_shape = x.shape[:-len(self.event_shape)] + + if context is None: + x_flat = torch.flatten(x, start_dim=0, end_dim=len(batch_shape) - 1) + dataset = TensorDataset(x_flat) + else: + x_flat = torch.flatten(x, start_dim=0, end_dim=len(batch_shape) - 1) + context_flat = torch.flatten(context, start_dim=0, end_dim=len(batch_shape) - 1) + dataset = TensorDataset(x_flat, context_flat) + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False) outputs = [] log_dets = [] @@ -60,8 +69,8 @@ def batch_apply(fn, batch_size, *args): batch_out, batch_log_det = fn(*batch) outputs.append(batch_out) log_dets.append(batch_log_det) - outputs = torch.cat(outputs, dim=0) - log_dets = torch.cat(log_dets, dim=0) + outputs = torch.cat(outputs, dim=0).view_as(x) + log_dets = torch.cat(log_dets, dim=0).view(*batch_shape) return outputs, log_dets def batch_forward(self, x: torch.Tensor, batch_size: int, context: torch.Tensor = None): @@ -89,6 +98,7 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor def regularization(self): return 0.0 + def invert(bijection: Bijection) -> Bijection: """Swap the forward and inverse methods of the input bijection. @@ -104,6 +114,7 @@ class BijectiveComposition(Bijection): """ Composition of bijections. Inherits from Bijection. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], layers: List[Bijection],