Skip to content

Commit

Permalink
Merge pull request #52 from arnauqb/fix_jacobian
Browse files Browse the repository at this point in the history
Suggestion for catching shape error in Jacobian computations
  • Loading branch information
joelnmdyer authored Nov 11, 2024
2 parents 0dbe560 + 0fafca7 commit e3b2afe
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions blackbirds/infer/vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,11 @@ def __init__(
self.tensorboard_log_dir = tensorboard_log_dir
self.log_tensorboard = log_tensorboard

def _check_loss_scalar(self, data):
one_sample = self.loss(self.posterior_estimator.sample(1), data)
multiple_samples = self.loss(self.posterior_estimator.sample(2), data)
assert one_sample.dim() == 0 and multiple_samples.dim() == 1, "Loss should be a scalar value (i.e., a 0D torch.tensor)"

def step(self, data):
"""
Performs one training step.
Expand Down Expand Up @@ -480,6 +485,7 @@ def run(
- `n_epochs`: The number of epochs to run the calibrator for.
- `max_epochs_without_improvement`: The number of epochs without improvement after which the calibrator stops.
"""
self._check_loss_scalar(data)
if mpi_rank == 0 and self.log_tensorboard:
self.writer = SummaryWriter(log_dir=self.tensorboard_log_dir)
if self.initialize_estimator_to_prior:
Expand Down

0 comments on commit e3b2afe

Please sign in to comment.