From 2c260d318d6e1237ead9e7d34a8ca7ffb092f2e7 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 15 Aug 2024 14:15:04 +0200 Subject: [PATCH] Fix mixture log prob --- torchflows/base_distributions/mixture.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchflows/base_distributions/mixture.py b/torchflows/base_distributions/mixture.py index 17f3a0f..a6e7217 100644 --- a/torchflows/base_distributions/mixture.py +++ b/torchflows/base_distributions/mixture.py @@ -38,8 +38,8 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor: # We are assuming all components are normalized value = value.to(self.log_weights) batch_shape = get_batch_shape(value, self.event_shape) - log_probs = torch.zeros(*batch_shape, self.n_components).to(self.log_weights) - for i in range(self.n_components): + log_probs = torch.zeros(*batch_shape, len(self.components)).to(self.log_weights) + for i in range(len(self.components)): log_probs[..., i] = self.components[i].log_prob(value) sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))] return torch.logsumexp(self.log_weights[sample_shape_mask] + log_probs, dim=-1)