Skip to content

Commit

Permalink
Lazy load torch in surrogates (#165)
Browse files Browse the repository at this point in the history
Delay torch import until really needed
  • Loading branch information
AdrianSosic authored Apr 8, 2024
2 parents 58cd1e3 + 38709ed commit e741273
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 54 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- `torch` numeric types are now loaded lazily
- Reorganized acquisition.py into `acquisition` subpackage
- `torch` is imported lazily in `surrogates`

### Fixed
- `n_task_params` now evaluates to 1 if `task_idx == 0`
Expand Down Expand Up @@ -44,7 +45,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Full lookup backtesting example now tests different substance encodings
- Replaced unmaintained `mordred` dependency by `mordredcommunity`
- `SearchSpace`s now use `ndarray` instead of `Tensor`
- `SearchSpace`s now use `ndarray` instead of `Tensor`

### Fixed
- `from_simplex` now efficiently validated in `Campaign.validate_config`
Expand Down
7 changes: 4 additions & 3 deletions baybe/acquisition/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from typing import Optional

import torch
from attr import define
from botorch.acquisition import AcquisitionFunction
from torch import Tensor, cat, squeeze
from torch import Tensor


@define
Expand Down Expand Up @@ -56,7 +57,7 @@ def _lift_partial_part(self, partial_part: Tensor) -> Tensor:
disc_part = partial_part
cont_part = pinned_part
# Concat the parts and return the concatenated point
full_point = cat((disc_part, cont_part), -1)
full_point = torch.cat((disc_part, cont_part), -1)
return full_point

def __call__(self, variable_part: Tensor) -> Tensor:
Expand Down Expand Up @@ -87,6 +88,6 @@ def set_X_pending(self, X_pending: Optional[Tensor]):
"""
if X_pending is not None: # Lift point to hybrid space and add additional dim
X_pending = self._lift_partial_part(X_pending)
X_pending = squeeze(X_pending, -2)
X_pending = torch.squeeze(X_pending, -2)
# Now use the original set_X_pending function
self.acqf.set_X_pending(X_pending)
25 changes: 25 additions & 0 deletions baybe/acquisition/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Utilities for acquisition functions."""

from functools import partial

from botorch.acquisition import (
ExpectedImprovement,
PosteriorMean,
ProbabilityOfImprovement,
UpperConfidenceBound,
qExpectedImprovement,
qProbabilityOfImprovement,
qUpperConfidenceBound,
)

# Lookup table from the short string name of an acquisition function to the
# botorch class (or pre-configured constructor) that implements it.
# Entries wrapped in `partial` pre-bind the `beta` keyword argument, so callers
# of those entries do not need to pass it themselves.
acquisition_function_mapping = {
    "PM": PosteriorMean,
    "PI": ProbabilityOfImprovement,
    "EI": ExpectedImprovement,
    "UCB": partial(UpperConfidenceBound, beta=1.0),
    # Names prefixed with "q" map to the corresponding botorch `q*` classes.
    "qEI": qExpectedImprovement,
    "qPI": qProbabilityOfImprovement,
    "qUCB": partial(qUpperConfidenceBound, beta=1.0),
    # Same classes as "UCB"/"qUCB" but with beta=100.0 instead of 1.0.
    # NOTE(review): the large beta presumably makes these variants dominated by
    # the posterior uncertainty term (hence "Var") -- confirm against botorch.
    "VarUCB": partial(UpperConfidenceBound, beta=100.0),
    "qVarUCB": partial(qUpperConfidenceBound, beta=100.0),
}
26 changes: 3 additions & 23 deletions baybe/recommenders/pure/bayesian/base.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
"""Base class for all Bayesian recommenders."""

from abc import ABC
from functools import partial
from typing import Callable, Literal, Optional

import pandas as pd
from attrs import define, field
from botorch.acquisition import (
AcquisitionFunction,
ExpectedImprovement,
PosteriorMean,
ProbabilityOfImprovement,
UpperConfidenceBound,
qExpectedImprovement,
qProbabilityOfImprovement,
qUpperConfidenceBound,
)
from botorch.acquisition import AcquisitionFunction

from baybe.acquisition import debotorchize
from baybe.acquisition.utils import acquisition_function_mapping
from baybe.recommenders.pure.base import PureRecommender
from baybe.searchspace import SearchSpace
from baybe.surrogates import _ONNX_INSTALLED, GaussianProcessSurrogate
Expand Down Expand Up @@ -53,18 +44,7 @@ def _get_acquisition_function_cls(
Returns:
The debotorchized acquisition function class.
"""
mapping = {
"PM": PosteriorMean,
"PI": ProbabilityOfImprovement,
"EI": ExpectedImprovement,
"UCB": partial(UpperConfidenceBound, beta=1.0),
"qEI": qExpectedImprovement,
"qPI": qProbabilityOfImprovement,
"qUCB": partial(qUpperConfidenceBound, beta=1.0),
"VarUCB": partial(UpperConfidenceBound, beta=100.0),
"qVarUCB": partial(qUpperConfidenceBound, beta=100.0),
}
fun = debotorchize(mapping[self.acquisition_function_cls])
fun = debotorchize(acquisition_function_mapping[self.acquisition_function_cls])
return fun

def setup_acquisition_function(
Expand Down
11 changes: 8 additions & 3 deletions baybe/surrogates/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
"""Base functionality for all BayBE surrogates."""

from __future__ import annotations

import gc
import sys
from abc import ABC, abstractmethod
from typing import Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar

import torch
from attr import define, field
from torch import Tensor

from baybe.searchspace import SearchSpace
from baybe.serialization import SerialMixin, converter, unstructure_base
from baybe.surrogates.utils import _prepare_inputs, _prepare_targets
from baybe.utils.basic import get_subclasses

if TYPE_CHECKING:
from torch import Tensor

# Define constants
_MIN_VARIANCE = 1e-6
_WRAPPER_MODELS = (
Expand Down Expand Up @@ -69,6 +72,8 @@ def posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
The posterior means and posterior covariance matrices of the t-batched
candidate points.
"""
import torch

# Prepare the input
candidates = _prepare_inputs(candidates)

Expand Down
17 changes: 13 additions & 4 deletions baybe/surrogates/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
It is planned to solve this issue in the future.
"""
from __future__ import annotations

from typing import Callable, ClassVar
from typing import TYPE_CHECKING, Callable, ClassVar

import torch
from attrs import define, field, validators
from torch import Tensor
from attrs import define, field, resolve_types, validators

from baybe.exceptions import ModelParamsNotSupportedError
from baybe.parameters import (
Expand All @@ -37,6 +36,9 @@
except ImportError:
_ONNX_INSTALLED = False

if TYPE_CHECKING:
from torch import Tensor


def register_custom_architecture(
joint_posterior_attr: bool = False,
Expand Down Expand Up @@ -153,6 +155,8 @@ def __attrs_post_init__(self) -> None:

@batchify
def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
import torch

model_inputs = {
self.onnx_input_name: candidates.numpy().astype(DTypeFloatONNX)
}
Expand Down Expand Up @@ -211,3 +215,8 @@ def validate_compatibility(cls, searchspace: SearchSpace) -> None:
f"a one-dimensional computational representation or "
f"{CustomDiscreteParameter.__name__}."
)

# FIXME: This manual resolve should not be necessary if the classes are declared
# properly. Potentially related to the conditional class definition, which should
# vanish as well.
resolve_types(CustomONNXSurrogate)
11 changes: 8 additions & 3 deletions baybe/surrogates/gaussian_process.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Gaussian process surrogates."""

from typing import Any, ClassVar, Optional
from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Optional

import torch
from attr import define, field
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP
Expand All @@ -12,12 +13,14 @@
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.priors import GammaPrior
from torch import Tensor

from baybe.searchspace import SearchSpace
from baybe.surrogates.base import Surrogate
from baybe.surrogates.validation import get_model_params_validator

if TYPE_CHECKING:
from torch import Tensor


@define
class GaussianProcessSurrogate(Surrogate):
Expand Down Expand Up @@ -49,6 +52,8 @@ def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None:
# See base class.

import torch

# identify the indexes of the task and numeric dimensions
# TODO: generalize to multiple task parameters
task_idx = searchspace.task_idx
Expand Down
11 changes: 8 additions & 3 deletions baybe/surrogates/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
Since we plan to refactor the surrogates, this part of the documentation will be
available in the future. Thus, please have a look in the source code directly.
"""
from __future__ import annotations

from typing import Any, ClassVar, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Optional

import torch
from attr import define, field
from sklearn.linear_model import ARDRegression
from torch import Tensor

from baybe.searchspace import SearchSpace
from baybe.surrogates.base import Surrogate
from baybe.surrogates.utils import batchify, catch_constant_targets, scale_model
from baybe.surrogates.validation import get_model_params_validator

if TYPE_CHECKING:
from torch import Tensor


@catch_constant_targets
@scale_model
Expand Down Expand Up @@ -47,6 +49,9 @@ class BayesianLinearSurrogate(Surrogate):
@batchify
def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
# See base class.

import torch

# Get predictions
dists = self._model.predict(candidates.numpy(), return_std=True)

Expand Down
12 changes: 9 additions & 3 deletions baybe/surrogates/naive.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""Naive surrogates."""

from typing import ClassVar, Optional
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, Optional

import torch
from attr import define, field
from torch import Tensor

from baybe.searchspace import SearchSpace
from baybe.surrogates.base import Surrogate
from baybe.surrogates.utils import batchify

if TYPE_CHECKING:
from torch import Tensor


@define
class MeanPredictionSurrogate(Surrogate):
Expand All @@ -33,6 +36,9 @@ class MeanPredictionSurrogate(Surrogate):
@batchify
def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
# See base class.

import torch

# TODO: use target value bounds for covariance scaling when explicitly provided
mean = self.target_value * torch.ones([len(candidates)])
var = torch.ones(len(candidates))
Expand Down
11 changes: 8 additions & 3 deletions baybe/surrogates/ngboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
Since we plan to refactor the surrogates, this part of the documentation will be
available in the future. Thus, please have a look in the source code directly.
"""
from __future__ import annotations

from typing import Any, ClassVar, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Optional

import torch
from attr import define, field
from ngboost import NGBRegressor
from torch import Tensor

from baybe.searchspace import SearchSpace
from baybe.surrogates.base import Surrogate
from baybe.surrogates.utils import batchify, catch_constant_targets, scale_model
from baybe.surrogates.validation import get_model_params_validator

if TYPE_CHECKING:
from torch import Tensor


@catch_constant_targets
@scale_model
Expand Down Expand Up @@ -53,6 +55,9 @@ def __attrs_post_init__(self):
@batchify
def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
# See base class.

import torch

# Get predictions
dists = self._model.pred_dist(candidates)

Expand Down
10 changes: 7 additions & 3 deletions baybe/surrogates/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@
Since we plan to refactor the surrogates, this part of the documentation will be
available in the future. Thus, please have a look in the source code directly.
"""
from __future__ import annotations

from typing import Any, ClassVar, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Optional

import numpy as np
import torch
from attr import define, field
from sklearn.ensemble import RandomForestRegressor
from torch import Tensor

from baybe.searchspace import SearchSpace
from baybe.surrogates.base import Surrogate
from baybe.surrogates.utils import batchify, catch_constant_targets, scale_model
from baybe.surrogates.validation import get_model_params_validator

if TYPE_CHECKING:
from torch import Tensor


@catch_constant_targets
@scale_model
Expand Down Expand Up @@ -49,6 +51,8 @@ class RandomForestSurrogate(Surrogate):
def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
# See base class.

import torch

# Evaluate all trees
# NOTE: explicit conversion to ndarray is needed due to a pytorch issue:
# https://github.com/pytorch/pytorch/pull/51731
Expand Down
Loading

0 comments on commit e741273

Please sign in to comment.