Merge pull request #252 from alan-turing-institute/pr-243
Pr 243
Showing 6 changed files with 432 additions and 17 deletions.
3 changes: 3 additions & 0 deletions
autoemulate/emulators/gaussian_process_utils/__init__.py
@@ -0,0 +1,3 @@
from .early_stopping_criterion import EarlyStoppingCustom
from .poly_mean import PolyMean
from .polynomial_features import PolynomialFeatures
64 changes: 64 additions & 0 deletions
autoemulate/emulators/gaussian_process_utils/early_stopping_criterion.py
@@ -0,0 +1,64 @@
import numpy as np
from skorch.callbacks import EarlyStopping


class EarlyStoppingCustom(EarlyStopping):
    """Callback for stopping training when scores don't improve.

    Stop training early if the specified `monitor` metric did not
    improve within `patience` epochs by at least `threshold`.

    **Note**: This version is virtually identical to `EarlyStopping`,
    except that the method `_calc_new_threshold` is corrected to
    ensure monotonicity; see also
    https://github.com/skorch-dev/skorch/pull/1065.

    Parameters
    ----------
    monitor : str (default='valid_loss')
        Value of the history to monitor to decide whether to stop
        training or not. The value is expected to be a double and is
        commonly provided by scoring callbacks such as
        :class:`skorch.callbacks.EpochScoring`.
    lower_is_better : bool (default=True)
        Whether lower scores should be considered better or worse.
    patience : int (default=5)
        Number of epochs to wait for improvement of the monitor value
        until the training process is stopped.
    threshold : float (default=1e-4)
        Ignore score improvements smaller than `threshold`.
    threshold_mode : str (default='rel')
        One of `rel`, `abs`. Decides whether the `threshold` value is
        interpreted in absolute terms or as a fraction of the best
        score so far (relative).
    sink : callable (default=print)
        The target that the information about early stopping is
        sent to. By default, the output is printed to stdout, but the
        sink could also be a logger or :func:`~skorch.utils.noop`.
    load_best : bool (default=False)
        Whether to restore module weights from the epoch with the best value of
        the monitored quantity. If False, the module weights obtained at the
        last step of training are used. Note that only the module is restored.
        Use the ``Checkpoint`` callback with the :attr:`~Checkpoint.load_best`
        argument set to ``True`` if you need to restore the whole object.
    """

    def _calc_new_threshold(self, score):
        """Determine the improvement threshold based on the current score."""
        if self.threshold_mode == "rel":
            # Scale by the magnitude of the score; taking np.abs here is
            # the monotonicity fix, so a negative score cannot flip the
            # direction of the required improvement.
            abs_threshold_change = self.threshold * np.abs(score)
        else:
            abs_threshold_change = self.threshold

        if self.lower_is_better:
            new_threshold = score - abs_threshold_change
        else:
            new_threshold = score + abs_threshold_change
        return new_threshold
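The monotonicity issue matters when the monitored score is negative: upstream `EarlyStopping` scaled the relative threshold by the raw score, so the adjustment changed sign and spurious "improvements" could reset the patience counter. Below is a minimal usage sketch (not part of this diff) showing the callback wired into a skorch regressor; the module and hyperparameters are illustrative placeholders.

import torch.nn as nn
from skorch import NeuralNetRegressor

from autoemulate.emulators.gaussian_process_utils import EarlyStoppingCustom

# Stand-in module and settings for illustration; any torch.nn.Module works.
net = NeuralNetRegressor(
    nn.Linear(10, 1),
    max_epochs=100,
    callbacks=[
        # Stop when `valid_loss` fails to improve by a relative 1e-4
        # for 10 consecutive epochs.
        EarlyStoppingCustom(
            monitor="valid_loss",
            patience=10,
            threshold=1e-4,
            threshold_mode="rel",
        )
    ],
)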
53 changes: 53 additions & 0 deletions
autoemulate/emulators/gaussian_process_utils/poly_mean.py
@@ -0,0 +1,53 @@
import gpytorch
import torch

from .polynomial_features import PolynomialFeatures


class PolyMean(gpytorch.means.Mean):
    """A general polynomial mean module used to construct
    `gaussian_process_torch` emulators.

    Parameters
    ----------
    degree : int
        The degree of the polynomial for which we are defining
        the mapping.
    input_size : int
        The number of features to be mapped.
    batch_shape : torch.Size
        Batch shape for the learned parameters (defaults to an
        empty ``torch.Size()``).
    bias : bool
        Flag for including a bias in the definition of the polynomial.
        If set to `False`, the polynomial includes weights only.
    """

    def __init__(self, degree, input_size, batch_shape=torch.Size(), bias=True):
        super().__init__()
        self.degree = degree
        self.input_size = input_size

        # Precompute the index combinations that define the polynomial terms.
        self.poly = PolynomialFeatures(self.degree, self.input_size)
        self.poly.fit()

        n_weights = len(self.poly.indices)
        self.register_parameter(
            name="weights",
            parameter=torch.nn.Parameter(torch.randn(*batch_shape, n_weights, 1)),
        )

        if bias:
            self.register_parameter(
                name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1))
            )
        else:
            self.bias = None

    def forward(self, x):
        x_ = self.poly.transform(x)
        res = x_.matmul(self.weights).squeeze(-1)
        if self.bias is not None:
            res = res + self.bias
        return res

    def __repr__(self):
        return f"PolyMean(degree={self.degree}, input_size={self.input_size})"
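For orientation, here is a minimal sketch (not part of this diff) of PolyMean serving as the mean of a standard gpytorch exact GP. The model class, kernel choice, and training tensors are illustrative assumptions, not the emulator wiring used by autoemulate.

import gpytorch
import torch

from autoemulate.emulators.gaussian_process_utils import PolyMean


class GPWithPolyMean(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # A quadratic trend over the input features instead of a constant mean.
        self.mean_module = PolyMean(degree=2, input_size=train_x.shape[-1])
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


train_x = torch.rand(20, 2)
train_y = torch.sin(train_x.sum(dim=-1))
model = GPWithPolyMean(train_x, train_y, gpytorch.likelihoods.GaussianLikelihood())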
74 changes: 74 additions & 0 deletions
autoemulate/emulators/gaussian_process_utils/polynomial_features.py
@@ -0,0 +1,74 @@
import itertools

import numpy as np
import torch


class PolynomialFeatures:
    """
    Maps an existing feature set to a polynomial feature set.

    Examples
    --------
    Initialize the class to map the feature `X` (`n1` samples x `n2` features):

    >>> pf = PolynomialFeatures(degree=2, input_size=X.shape[1])

    Fit the instance to predefine which features are multiplied together
    to create the new features:

    >>> pf.fit()

    Generate the new polynomial feature set:

    >>> X_deg_2 = pf.transform(X)

    Parameters
    ----------
    degree : int
        The degree of the polynomial for which we are defining
        the mapping.
    input_size : int
        The number of features to be mapped.
    """

    def __init__(self, degree: int, input_size: int):
        assert degree > 0, "`degree` input must be greater than 0."
        assert (
            input_size > 0
        ), "`input_size`, which defines the number of features, has to be greater than 0."
        self.degree = degree
        self.indices = None
        self.input_size = input_size

    def fit(self):
        x = list(range(self.input_size))

        # Collect, for each degree d, the unique sorted index tuples, i.e.
        # the combinations with replacement of the input feature indices.
        d = self.degree
        L = []
        while d > 1:
            l = [list(p) for p in itertools.product(x, repeat=d)]
            for li in l:
                li.sort()
            L += list(map(list, np.unique(l, axis=0)))
            d -= 1
        L += [[i] for i in x]

        # Order the terms by degree (linear terms first, then quadratic, ...).
        Ls = []
        for d in range(1, self.degree + 1):
            ld = []
            for l in L:
                if len(l) == d:
                    ld.append(l)
            ld.sort()
            Ls += ld
        self.indices = Ls

    def transform(self, x):
        if not self.indices:
            raise ValueError(
                "self.indices is set to None. Did you forget to call 'fit'?"
            )

        # Each new feature is the product of the input features selected by
        # one index list, e.g. [0, 1] -> x[..., 0] * x[..., 1].
        x_ = torch.stack([torch.prod(x[..., i], dim=-1) for i in self.indices], dim=-1)
        return x_
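A short worked example (not part of this diff) of what fit and transform produce for degree=2 over 2 input features; the import path follows the package `__init__.py` added above.

import torch

from autoemulate.emulators.gaussian_process_utils import PolynomialFeatures

pf = PolynomialFeatures(degree=2, input_size=2)
pf.fit()
print(pf.indices)  # [[0], [1], [0, 0], [0, 1], [1, 1]]

# Features x0=2, x1=3 map to [x0, x1, x0*x0, x0*x1, x1*x1].
X = torch.tensor([[2.0, 3.0]])
print(pf.transform(X))  # tensor([[2., 3., 4., 6., 9.]])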