Skip to content

Commit

Permalink
Fix shape error when sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Oct 3, 2023
1 parent 2962718 commit 70f9b30
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
13 changes: 10 additions & 3 deletions normalizing_flows/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from normalizing_flows.bijections.base import Bijection
from normalizing_flows.bijections.continuous.ddnf import DeepDiffeomorphicBijection
from normalizing_flows.regularization import reconstruction_error
from normalizing_flows.utils import flatten_event, get_batch_shape
from normalizing_flows.utils import flatten_event, get_batch_shape, unflatten_event


class Flow(nn.Module):
Expand All @@ -24,6 +24,13 @@ def base_log_prob(self, z):
log_prob = self.base.log_prob(zf)
return log_prob

def base_sample(self, sample_shape):
z_flat = self.base.sample(sample_shape)
z = unflatten_event(z_flat, self.bijection.event_shape)
return z



def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None):
if context is not None:
assert context.shape[0] == x.shape[0]
Expand All @@ -46,11 +53,11 @@ def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, re
:return:
"""
if context is not None:
z = self.base.sample(sample_shape=torch.Size((n, len(context))))
z = self.base_sample(sample_shape=torch.Size((n, len(context))))
context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape
assert z.shape[:2] == context.shape[:2]
else:
z = self.base.sample(sample_shape=torch.Size((n,)))
z = self.base_sample(sample_shape=torch.Size((n,)))
if no_grad:
z = z.detach()
with torch.no_grad():
Expand Down
11 changes: 11 additions & 0 deletions test/test_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch

from normalizing_flows import Flow
from normalizing_flows.bijections import RealNVP


def test_real_nvp():
torch.manual_seed(0)
f = Flow(RealNVP(event_shape=torch.Size((2, 3, 5, 7))))
x = f.sample(10)
assert x.shape == (10, 2, 3, 5, 7)

0 comments on commit 70f9b30

Please sign in to comment.