fix memory leak in autoregression
RichieHakim committed Mar 21, 2024
1 parent d390763 commit 3e0d3b5
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bnpm/automatic_regression.py
@@ -293,7 +293,7 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
         if len(loss_train_all) == 1:
             loss_train, loss_test, loss = loss_train_all[0], loss_test_all[0], loss_all[0]
         else:
-            if isinstance(loss_train_all[0], (np.ndarray, np.generic)):
+            if isinstance(loss_train_all[0], (np.ndarray, np.generic, float, int)):
                 stack = np.stack
             elif isinstance(loss_train_all[0], torch.Tensor):
                 stack = torch.stack
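
The type-check change above pairs with the .item() change in the next hunk: once the torch losses are returned as plain Python floats, the scalar case has to be routed through np.stack. A minimal sketch (not code from bnpm; the loss values are hypothetical) showing that np.stack accepts a list of floats:

import numpy as np

# Hypothetical per-fold losses, now plain Python floats after the .item() change.
loss_train_all = [0.52, 0.48, 0.51]

if isinstance(loss_train_all[0], (np.ndarray, np.generic, float, int)):
    stack = np.stack

stacked = stack(loss_train_all)  # scalars are promoted to 0-d arrays, then stacked
print(stacked.shape)             # (3,)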
@@ -686,8 +686,8 @@ def __call__(
         elif isinstance(y_train_pred, torch.Tensor):
             from torch.nn.functional import mse_loss
             assert (sample_weight_test is None) and (sample_weight_train is None), 'sample weights not supported for torch tensors.'
-            loss_train = mse_loss(y_train_pred, y_train_true, reduction='mean')
-            loss_test = mse_loss(y_test_pred, y_test_true, reduction='mean')
+            loss_train = mse_loss(y_train_pred, y_train_true, reduction='mean').item()
+            loss_test = mse_loss(y_test_pred, y_test_true, reduction='mean').item()
             loss = self.fn_penalty_testTrainRatio(loss_test, loss_train)
         else:
             raise ValueError(f'Expected y_train_pred to be of type np.ndarray or torch.Tensor, but got type {type(y_train_pred)}.')
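
The memory leak named in the commit message comes from storing the raw mse_loss outputs: each returned torch.Tensor keeps its autograd graph (and any GPU storage it references) alive for as long as the value is retained, e.g. across Optuna trials. Calling .item() returns a Python float, so each trial's graph becomes garbage-collectable. A minimal sketch of the two patterns, using made-up tensors rather than anything from bnpm:

import torch

def compute_loss(w):
    # Dummy forward pass; the returned tensor carries an autograd graph back to w.
    x = torch.randn(1000)
    y = torch.randn(1000)
    return torch.nn.functional.mse_loss(x * w, y, reduction='mean')

w = torch.tensor(2.0, requires_grad=True)

# Leak-prone pattern: the stored tensors keep their graphs (and any buffers
# they reference) alive for the lifetime of the list.
losses_as_tensors = [compute_loss(w) for _ in range(100)]

# Pattern used by this commit: .item() yields a plain float, so each trial's
# graph can be freed as soon as the iteration ends.
losses_as_floats = [compute_loss(w).item() for _ in range(100)]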
