Skip to content

Commit

Permalink
Add second option: iterate through combinatorial list
Browse files Browse the repository at this point in the history
  • Loading branch information
Waschenbacher committed Jul 3, 2024
1 parent 391fe65 commit 0671dd2
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 7 deletions.
24 changes: 24 additions & 0 deletions baybe/constraints/continuous.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Continuous constraints."""

import math
from itertools import combinations
from math import comb

import numpy as np
from attrs import define
Expand Down Expand Up @@ -46,6 +48,28 @@ class ContinuousCardinalityConstraint(
):
"""Class for continuous cardinality constraints."""

@property
def combinatorial_counts_zero_parameters(self) -> int:
"""Return the total number of all possible combinations of zero parameters."""
combinatorial_counts = 0
for i_zeros in range(
len(self.parameters) - self.max_cardinality,
len(self.parameters) - self.min_cardinality + 1,
):
combinatorial_counts += comb(len(self.parameters), i_zeros)
return combinatorial_counts

@property
def combinatorial_zero_parameters(self) -> list[tuple[str, ...]]:
"""Return a combinatorial list of all possible zero parameters."""
combinatorial_zeros = []
for i_zeros in range(
len(self.parameters) - self.max_cardinality,
len(self.parameters) - self.min_cardinality + 1,
):
combinatorial_zeros.extend(combinations(self.parameters, i_zeros))
return combinatorial_zeros

def sample_inactive_parameters(self, batch_size: int = 1) -> list[set[str]]:
"""Sample sets of inactive parameters according to the cardinality constraints.
Expand Down
33 changes: 28 additions & 5 deletions baybe/recommenders/pure/bayesian/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
sample_numerical_df,
)

N_RESTART_CARDINALITY = 5
N_ITER_THRESHOLD = 10


@define(kw_only=True)
Expand Down Expand Up @@ -185,11 +185,34 @@ def _recommend_continuous_on_subspace(
if len(subspace_continuous.constraints_cardinality):
acqf_values_all: list[Tensor] = []
points_all: list[Tensor] = []
for _ in range(N_RESTART_CARDINALITY):
# Randomly set some parameters inactive
inactive_params_sample = (
subspace_continuous._sample_inactive_parameters(1)[0]

# When the size of the full list of inactive parameters is not too large,
# we can iterate through the full list; otherwise we randomly set some
# parameters inactive.
_iterator = (
subspace_continuous.combinatorial_zero_parameters
if (
combinatorial_counts
:= subspace_continuous.combinatorial_counts_zero_parameters
)
<= N_ITER_THRESHOLD
else range(N_ITER_THRESHOLD)
)

for inactive_params_generator in _iterator:
if combinatorial_counts <= N_ITER_THRESHOLD:
# Iterate through the combinations of all possible inactive
# parameters.
inactive_params_sample = {
param
for sublist in inactive_params_generator
for param in sublist
}
else:
# Randomly set some parameters inactive
inactive_params_sample = (
subspace_continuous._sample_inactive_parameters(1)[0]
)

if len(inactive_params_sample):
# Turn inactive parameters to fixed features (used as input in
Expand Down
35 changes: 33 additions & 2 deletions baybe/searchspace/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from __future__ import annotations

import warnings
from collections.abc import Collection, Sequence
from itertools import chain
from collections.abc import Collection, Iterable, Sequence
from functools import reduce
from itertools import chain, product
from typing import TYPE_CHECKING, Any, cast

import numpy as np
Expand Down Expand Up @@ -111,6 +112,36 @@ def constraints_cardinality(self) -> tuple[ContinuousCardinalityConstraint, ...]
if isinstance(c, ContinuousCardinalityConstraint)
)

@property
def combinatorial_counts_zero_parameters(self) -> int:
"""Return the total number of all possible combinations of zero parameters."""
# Note that both continuous subspace and continuous cardinality constraint
# have this property. This property is the counts for the subspace
# parameters; while the latter one is the counts only for that constraint.
if self.constraints_cardinality:
return reduce(
lambda x, y: x * y,
[
con.combinatorial_counts_zero_parameters
for con in self.constraints_cardinality
],
)
else:
return 0

@property
def combinatorial_zero_parameters(self) -> Iterable[tuple[str, ...]]:
"""Return a combinatorial list of all possible zero parameters on subspace."""
# The comments on the difference in `combinatorial_counts_zero_parameters`
# applies here as well.
if self.constraints_cardinality:
return product(
*[
con.combinatorial_zero_parameters
for con in self.constraints_cardinality
]
)

@constraints_nonlin.validator
def _validate_constraints_nonlin(self, _, __) -> None:
"""Validate nonlinear constraints."""
Expand Down

0 comments on commit 0671dd2

Please sign in to comment.