Skip to content

Commit

Permalink
FlowMixture fixes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Feb 9, 2024
1 parent 474460f commit 4fb4a00
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
12 changes: 10 additions & 2 deletions normalizing_flows/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,13 @@ def regularization(self):


class FlowMixture(BaseFlow):
def __init__(self, flows: List[Flow], weights: List[float]):
def __init__(self, flows: List[Flow], weights: List[float] = None):
super().__init__(event_shape=flows[0].event_shape)

# Use uniform weights by default
if weights is None:
weights = [1.0 / len(flows)] * len(flows)

assert len(weights) == len(flows)
assert all([w > 0.0 for w in weights])
assert np.isclose(sum(weights), 1.0)
Expand Down Expand Up @@ -334,7 +339,10 @@ def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, re
flow_samples = torch.stack(flow_samples) # (n_flows, n, *event_shape)
categorical_samples = self.categorical_distribution.sample(sample_shape=torch.Size((n,))) # (n,)
one_hot = torch.nn.functional.one_hot(categorical_samples, num_classes=len(flow_samples)).T # (n_flows, n)
samples = torch.sum(one_hot * flow_samples, dim=0) # (n, *event_shape)
one_hot_reshaped = one_hot.view(*one_hot.shape, *([1] * len(self.event_shape)))
# (n_flows, n, *event_shape)

samples = torch.sum(one_hot_reshaped * flow_samples, dim=0) # (n, *event_shape)

if return_log_prob:
flow_log_probs = torch.stack(flow_log_probs) # (n_flows, n)
Expand Down
69 changes: 69 additions & 0 deletions test/test_mixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from normalizing_flows.flows import FlowMixture, Flow
from normalizing_flows.architectures import RealNVP, NICE, CouplingRQNSF
import torch


def test_basic():
torch.manual_seed(0)

n_data = 100
n_dim = 10
x = torch.randn(size=(n_data, n_dim))

mixture = FlowMixture([
Flow(RealNVP(event_shape=(n_dim,))),
Flow(NICE(event_shape=(n_dim,))),
Flow(CouplingRQNSF(event_shape=(n_dim,)))
])

log_prob = mixture.log_prob(x)
assert log_prob.shape == (n_data,)
assert torch.all(torch.isfinite(log_prob))

x_sampled = mixture.sample(n_data)
assert x_sampled.shape == x.shape
assert torch.all(torch.isfinite(x_sampled))


def test_medium():
torch.manual_seed(0)

n_data = 1000
n_dim = 100
x = torch.randn(size=(n_data, n_dim))

mixture = FlowMixture([
Flow(RealNVP(event_shape=(n_dim,))),
Flow(NICE(event_shape=(n_dim,))),
Flow(CouplingRQNSF(event_shape=(n_dim,)))
])

log_prob = mixture.log_prob(x)
assert log_prob.shape == (n_data,)
assert torch.all(torch.isfinite(log_prob))

x_sampled = mixture.sample(n_data)
assert x_sampled.shape == x.shape
assert torch.all(torch.isfinite(x_sampled))


def test_complex_event():
torch.manual_seed(0)

n_data = 1000
event_shape = (2, 3, 4, 5)
x = torch.randn(size=(n_data, *event_shape))

mixture = FlowMixture([
Flow(RealNVP(event_shape=event_shape)),
Flow(NICE(event_shape=event_shape)),
Flow(CouplingRQNSF(event_shape=event_shape))
])

log_prob = mixture.log_prob(x)
assert log_prob.shape == (n_data,)
assert torch.all(torch.isfinite(log_prob))

x_sampled = mixture.sample(n_data)
assert x_sampled.shape == x.shape
assert torch.all(torch.isfinite(x_sampled))

0 comments on commit 4fb4a00

Please sign in to comment.