Skip to content

Commit

Permalink
Simplify multi-space optimization logic
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Oct 29, 2024
1 parent a5096a3 commit 8a7ef15
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions baybe/recommenders/pure/bayesian/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Collection, Iterable
from typing import TYPE_CHECKING, Any, ClassVar

import numpy as np
import pandas as pd
from attrs import define, field
from attrs.converters import optional as optional_c
Expand Down Expand Up @@ -447,10 +448,17 @@ def __str__(self) -> str:

def _optimize_subspaces_without_cardinality_constraints(
self, subspaces: Iterable[SubspaceContinuous], batch_size: int
) -> tuple[Tensor, Tensor]:
import torch
) -> tuple[Tensor, float]:
"""Find the optimum candidates from multiple subspaces.
Args:
subspaces: The subspaces to consider for the optimization.
batch_size: The number of points to be recommended.
acqf_values_all: list[Tensor] = []
Returns:
The batch of candidates and the corresponding acquisition value.
"""
acqf_values_all: list[float] = []
points_all: list[Tensor] = []

for subspace in subspaces:
Expand All @@ -459,19 +467,18 @@ def _optimize_subspaces_without_cardinality_constraints(
f = self._recommend_continuous_without_cardinality_constraints
points_i, acqf_values_i = f(subspace, batch_size)

# Append recommendation list and acquisition function values
# Append optimization results
points_all.append(points_i.unsqueeze(0))
acqf_values_all.append(acqf_values_i.unsqueeze(0))
acqf_values_all.append(acqf_values_i.item())

# # The optimization problem may be infeasible for certain inactive
# # parameters. The optimize_acqf raises a ValueError when the optimization
# # problem is infeasible.
# The optimization problem may be infeasible in certain subspaces
except ValueError:
pass

# Find the best option
points = torch.cat(points_all)[torch.argmax(torch.cat(acqf_values_all)), :]
acqf_values = torch.max(torch.cat(acqf_values_all))
# Find the best option f
best_idx = np.argmax(acqf_values_all)
points = points_all[best_idx]
acqf_values = acqf_values_all[best_idx]

return points, acqf_values

Expand Down

0 comments on commit 8a7ef15

Please sign in to comment.