Skip to content

Commit

Permalink
Fix mixture log prob
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Aug 15, 2024
1 parent 32dc061 commit 2c260d3
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchflows/base_distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
# We are assuming all components are normalized
value = value.to(self.log_weights)
batch_shape = get_batch_shape(value, self.event_shape)
log_probs = torch.zeros(*batch_shape, self.n_components).to(self.log_weights)
for i in range(self.n_components):
log_probs = torch.zeros(*batch_shape, len(self.components)).to(self.log_weights)
for i in range(len(self.components)):
log_probs[..., i] = self.components[i].log_prob(value)
sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))]
return torch.logsumexp(self.log_weights[sample_shape_mask] + log_probs, dim=-1)
Expand Down

0 comments on commit 2c260d3

Please sign in to comment.