diff --git a/baybe/recommenders/pure/bayesian/botorch.py b/baybe/recommenders/pure/bayesian/botorch.py index 077c648a2..ae40a0720 100644 --- a/baybe/recommenders/pure/bayesian/botorch.py +++ b/baybe/recommenders/pure/bayesian/botorch.py @@ -7,6 +7,7 @@ from collections.abc import Collection, Iterable from typing import TYPE_CHECKING, Any, ClassVar +import numpy as np import pandas as pd from attrs import define, field from attrs.converters import optional as optional_c @@ -447,10 +448,17 @@ def __str__(self) -> str: def _optimize_subspaces_without_cardinality_constraints( self, subspaces: Iterable[SubspaceContinuous], batch_size: int - ) -> tuple[Tensor, Tensor]: - import torch + ) -> tuple[Tensor, float]: + """Find the optimum candidates from multiple subspaces. + + Args: + subspaces: The subspaces to consider for the optimization. + batch_size: The number of points to be recommended. - acqf_values_all: list[Tensor] = [] + Returns: + The batch of candidates and the corresponding acquisition value. + """ + acqf_values_all: list[float] = [] points_all: list[Tensor] = [] for subspace in subspaces: @@ -459,19 +467,18 @@ def _optimize_subspaces_without_cardinality_constraints( f = self._recommend_continuous_without_cardinality_constraints points_i, acqf_values_i = f(subspace, batch_size) - # Append recommendation list and acquisition function values + # Append optimization results points_all.append(points_i.unsqueeze(0)) - acqf_values_all.append(acqf_values_i.unsqueeze(0)) + acqf_values_all.append(acqf_values_i.item()) - # # The optimization problem may be infeasible for certain inactive - # # parameters. The optimize_acqf raises a ValueError when the optimization - # # problem is infeasible. + # The optimization problem may be infeasible in certain subspaces except ValueError: pass - # Find the best option - points = torch.cat(points_all)[torch.argmax(torch.cat(acqf_values_all)), :] - acqf_values = torch.max(torch.cat(acqf_values_all)) + # Find the best option f + best_idx = np.argmax(acqf_values_all) + points = points_all[best_idx] + acqf_values = acqf_values_all[best_idx] return points, acqf_values