From bf0529e1f447dc4b42856b099ee900e4ce1d9079 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 14:43:01 +0100 Subject: [PATCH] Add learnable weights in FlowMixture --- normalizing_flows/flows.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index d3056fc..465047c 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -302,7 +302,7 @@ def regularization(self): class FlowMixture(BaseFlow): - def __init__(self, flows: List[Flow], weights: List[float] = None): + def __init__(self, flows: List[Flow], weights: List[float] = None, trainable_weights: bool = False): super().__init__(event_shape=flows[0].event_shape) # Use uniform weights by default @@ -313,17 +313,18 @@ def __init__(self, flows: List[Flow], weights: List[float] = None): assert all([w > 0.0 for w in weights]) assert np.isclose(sum(weights), 1.0) - self.flows = flows - self.weights = torch.tensor(weights) - self.log_weights = torch.log(self.weights) - self.categorical_distribution = torch.distributions.Categorical(probs=self.weights) + self.flows = nn.ModuleList(flows) + if trainable_weights: + self.logit_weights = nn.Parameter(torch.log(torch.tensor(weights))) + else: + self.logit_weights = torch.log(torch.tensor(weights)) def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): flow_log_probs = torch.stack([flow.log_prob(x, context=context) for flow in self.flows]) # (n_flows, *batch_shape) batch_shape = flow_log_probs.shape[1:] - log_weights_reshaped = self.log_weights.view(-1, *([1] * len(batch_shape))) + log_weights_reshaped = self.logit_weights.view(-1, *([1] * len(batch_shape))) log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # batch_shape return log_prob @@ -335,18 +336,19 @@ def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, re flow_samples.append(flow_x) flow_log_probs.append(flow_log_prob) - with torch.no_grad(): - flow_samples = torch.stack(flow_samples) # (n_flows, n, *event_shape) - categorical_samples = self.categorical_distribution.sample(sample_shape=torch.Size((n,))) # (n,) - one_hot = torch.nn.functional.one_hot(categorical_samples, num_classes=len(flow_samples)).T # (n_flows, n) - one_hot_reshaped = one_hot.view(*one_hot.shape, *([1] * len(self.event_shape))) - # (n_flows, n, *event_shape) + flow_samples = torch.stack(flow_samples) # (n_flows, n, *event_shape) + categorical_samples = torch.distributions.Categorical(logits=self.logit_weights).sample( + sample_shape=torch.Size((n,)) + ) # (n,) + one_hot = torch.nn.functional.one_hot(categorical_samples, num_classes=len(flow_samples)).T # (n_flows, n) + one_hot_reshaped = one_hot.view(*one_hot.shape, *([1] * len(self.event_shape))) + # (n_flows, n, *event_shape) - samples = torch.sum(one_hot_reshaped * flow_samples, dim=0) # (n, *event_shape) + samples = torch.sum(one_hot_reshaped * flow_samples, dim=0) # (n, *event_shape) if return_log_prob: flow_log_probs = torch.stack(flow_log_probs) # (n_flows, n) - log_weights_reshaped = self.log_weights[:, None] # (n_flows, 1) + log_weights_reshaped = self.logit_weights[:, None] # (n_flows, 1) log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # (n,) return samples, log_prob else: