diff --git a/blackbirds/infer/vi.py b/blackbirds/infer/vi.py index 84490ed..9f3e71b 100644 --- a/blackbirds/infer/vi.py +++ b/blackbirds/infer/vi.py @@ -507,8 +507,8 @@ def run( "Loss/regularisation", regularisation_loss, epoch ) torch.save(self.best_estimator_state_dict, "last_estimator.pt") - if loss < self.best_loss: - self.best_loss = loss + if total_loss < self.best_loss: + self.best_loss = total_loss self.best_estimator_state_dict = deepcopy( self.posterior_estimator.state_dict() )