Skip to content

Commit

Permalink
Fix batch apply for complex batch shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Aug 30, 2024
1 parent 5d30fb9 commit dcbb653
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions torchflows/bijections/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class Bijection(nn.Module):
"""Bijection class.
"""

def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
context_shape: Union[torch.Size, Tuple[int, ...]] = None,
Expand Down Expand Up @@ -50,18 +51,26 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
"""
raise NotImplementedError

@staticmethod
def batch_apply(fn, batch_size, *args):
dataset = TensorDataset(*args)
def batch_apply(self, fn, batch_size, x, context=None):
batch_shape = x.shape[:-len(self.event_shape)]

if context is None:
x_flat = torch.flatten(x, start_dim=0, end_dim=len(batch_shape) - 1)
dataset = TensorDataset(x_flat)
else:
x_flat = torch.flatten(x, start_dim=0, end_dim=len(batch_shape) - 1)
context_flat = torch.flatten(context, start_dim=0, end_dim=len(batch_shape) - 1)
dataset = TensorDataset(x_flat, context_flat)

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
outputs = []
log_dets = []
for batch in data_loader:
batch_out, batch_log_det = fn(*batch)
outputs.append(batch_out)
log_dets.append(batch_log_det)
outputs = torch.cat(outputs, dim=0)
log_dets = torch.cat(log_dets, dim=0)
outputs = torch.cat(outputs, dim=0).view_as(x)
log_dets = torch.cat(log_dets, dim=0).view(*batch_shape)
return outputs, log_dets

def batch_forward(self, x: torch.Tensor, batch_size: int, context: torch.Tensor = None):
Expand Down Expand Up @@ -89,6 +98,7 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor
def regularization(self):
return 0.0


def invert(bijection: Bijection) -> Bijection:
"""Swap the forward and inverse methods of the input bijection.
Expand All @@ -104,6 +114,7 @@ class BijectiveComposition(Bijection):
"""
Composition of bijections. Inherits from Bijection.
"""

def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
layers: List[Bijection],
Expand Down

0 comments on commit dcbb653

Please sign in to comment.