Commit
Merge pull request #203 from alan-turing-institute/lgbm
LightGBM
mastoffel authored Mar 7, 2024
2 parents deb526b + 4d60b2e commit 5517f55
Showing 7 changed files with 187 additions and 231 deletions.
2 changes: 0 additions & 2 deletions autoemulate/compare.py
@@ -96,8 +96,6 @@ def setup(
         Number of jobs to run in parallel. `None` means 1, `-1` means using all processors.
     model_subset : list
         List of models to use. If None, uses all models in MODEL_REGISTRY.
-        Currently, any of: SecondOrderPolynomial, RBF, RandomForest, GradientBoosting,
-        GaussianProcess, SupportVectorMachines, XGBoost
     log_to_file : bool
         Whether to log to file.
     """
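With this change, "LightGBM" replaces "XGBoost" among the valid model_subset entries, and the docstring no longer hard-codes the model list. A minimal usage sketch, assuming compare.py's AutoEmulate class and that setup takes simulation inputs and outputs as X and y (the toy data and the compare() call are illustrative, not part of this diff):

import numpy as np
from autoemulate.compare import AutoEmulate

# Toy simulation data: 100 samples, 2 input parameters.
X = np.random.rand(100, 2)
y = X[:, 0] ** 2 + np.sin(X[:, 1])

ae = AutoEmulate()
ae.setup(X, y, model_subset=["LightGBM", "RandomForest"])
best_model = ae.compare()  # cross-validates the selected emulators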
4 changes: 2 additions & 2 deletions autoemulate/emulators/__init__.py
@@ -1,13 +1,13 @@
 from .gaussian_process import GaussianProcess
 from .gaussian_process_mogp import GaussianProcessMOGP
 from .gradient_boosting import GradientBoosting
+from .light_gbm import LightGBM
 from .neural_net_sk import NeuralNetSk
 from .neural_net_torch import NeuralNetTorch
 from .polynomials import SecondOrderPolynomial
 from .random_forest import RandomForest
 from .rbf import RadialBasisFunctions
 from .support_vector_machines import SupportVectorMachines
-from .xgboost import XGBoost

 MODEL_REGISTRY = {
     SecondOrderPolynomial().model_name: SecondOrderPolynomial(),
@@ -16,7 +16,7 @@
     GradientBoosting().model_name: GradientBoosting(),
     GaussianProcess().model_name: GaussianProcess(),
     SupportVectorMachines().model_name: SupportVectorMachines(),
-    XGBoost().model_name: XGBoost(),
+    LightGBM().model_name: LightGBM(),
     NeuralNetTorch(module="mlp").model_name: NeuralNetTorch(module="mlp"),
     NeuralNetTorch(module="rbf").model_name: NeuralNetTorch(module="rbf"),
     NeuralNetSk().model_name: NeuralNetSk(),
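Registry keys come from each emulator's model_name property (the class name, as defined in light_gbm.py below), so the new emulator is addressable by name. A small lookup sketch under that assumption:

from autoemulate.emulators import MODEL_REGISTRY

# The key "LightGBM" is produced by LightGBM().model_name.
lgbm = MODEL_REGISTRY["LightGBM"]
print(lgbm.get_grid_params(search_type="random").keys())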
144 changes: 144 additions & 0 deletions autoemulate/emulators/light_gbm.py
@@ -0,0 +1,144 @@
import numpy as np
from lightgbm import LGBMRegressor
from scipy.stats import loguniform
from scipy.stats import randint
from scipy.stats import uniform
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.validation import check_X_y
from skopt.space import Categorical
from skopt.space import Integer
from skopt.space import Real


class LightGBM(BaseEstimator, RegressorMixin):
    """LightGBM emulator.

    Wraps LGBMRegressor from the lightgbm package.
    """

    def __init__(
        self,
        boosting_type="gbdt",
        num_leaves=31,
        max_depth=-1,
        learning_rate=0.1,
        n_estimators=100,
        subsample_for_bin=200000,
        objective=None,
        class_weight=None,
        min_split_gain=0.0,
        min_child_weight=0.001,
        min_child_samples=20,
        subsample=1.0,
        # subsample_freq=0.0,
        colsample_bytree=1.0,
        reg_alpha=0.0,
        reg_lambda=0.0,
        random_state=None,
        n_jobs=1,
        importance_type="split",
        verbose=-1,
    ):
        """Initializes a LightGBM object."""
        self.boosting_type = boosting_type
        self.num_leaves = num_leaves
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.subsample_for_bin = subsample_for_bin
        self.objective = objective
        self.class_weight = class_weight
        self.min_split_gain = min_split_gain
        self.min_child_weight = min_child_weight
        self.min_child_samples = min_child_samples
        self.subsample = subsample
        # self.subsample_freq = subsample_freq
        self.colsample_bytree = colsample_bytree
        self.reg_alpha = reg_alpha
        self.reg_lambda = reg_lambda
        self.random_state = random_state
        self.n_jobs = n_jobs
        self.importance_type = importance_type
        self.verbose = verbose

    def fit(self, X, y, sample_weight=None, **kwargs):
        """Fits the emulator to the data."""
        X, y = check_X_y(
            X, y, multi_output=self._more_tags()["multioutput"], y_numeric=True
        )

        self.n_features_in_ = X.shape[1]

        self.model_ = LGBMRegressor(
            boosting_type=self.boosting_type,
            num_leaves=self.num_leaves,
            max_depth=self.max_depth,
            learning_rate=self.learning_rate,
            n_estimators=self.n_estimators,
            subsample_for_bin=self.subsample_for_bin,
            objective=self.objective,
            class_weight=self.class_weight,
            min_split_gain=self.min_split_gain,
            min_child_weight=self.min_child_weight,
            min_child_samples=self.min_child_samples,
            subsample=self.subsample,
            colsample_bytree=self.colsample_bytree,
            reg_alpha=self.reg_alpha,
            reg_lambda=self.reg_lambda,
            random_state=self.random_state,
            n_jobs=self.n_jobs,
            importance_type=self.importance_type,
            verbose=self.verbose,
        )

        self.model_.fit(X, y, sample_weight=sample_weight)
        self.is_fitted_ = True
        return self

    def predict(self, X):
        """Predicts the output of the emulator for a given input."""
        X = check_array(X)
        check_is_fitted(self, "is_fitted_")
        y_pred = self.model_.predict(X)
        return y_pred

    def get_grid_params(self, search_type="random"):
        """Returns the grid parameters of the emulator."""
        param_space_random = {
            "boosting_type": ["gbdt", "dart"],
            "num_leaves": randint(10, 100),
            "max_depth": randint(-1, 12),
            "learning_rate": loguniform(0.001, 0.1),
            "n_estimators": randint(50, 1000),
            # "colsample_bytree": uniform(0.5, 1.0),
            "reg_alpha": loguniform(0.001, 1),
            "reg_lambda": loguniform(0.001, 1),
        }

        param_space_bayes = {
            "boosting_type": Categorical(["gbdt", "dart"]),
            "num_leaves": Integer(10, 100),
            "max_depth": Integer(-1, 12),
            "learning_rate": Real(0.001, 0.1, prior="log-uniform"),
            "n_estimators": Integer(50, 1000),
            # "colsample_bytree": Real(0.5, 1.0),
            "reg_alpha": Real(0.001, 1, prior="log-uniform"),
            "reg_lambda": Real(0.001, 1, prior="log-uniform"),
        }

        if search_type == "random":
            param_space = param_space_random
        elif search_type == "bayes":
            param_space = param_space_bayes
        else:
            # Guard against silently returning an unbound name.
            raise ValueError(f"Unknown search_type: {search_type}")

        return param_space

    @property
    def model_name(self):
        return self.__class__.__name__

    def _more_tags(self):
        return {"multioutput": False}
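For reference, a short end-to-end sketch of the new emulator on toy data, including a randomized search over the distributions returned by get_grid_params (the data and search settings are illustrative assumptions, not part of this PR):

import numpy as np
from sklearn.model_selection import RandomizedSearchCV
from autoemulate.emulators import LightGBM

rng = np.random.default_rng(0)
X = rng.uniform(size=(200, 3))
y = np.sin(X[:, 0]) + X[:, 1] ** 2 + 0.1 * rng.normal(size=200)

# Plain scikit-learn-style fit/predict.
lgbm = LightGBM(random_state=0)
lgbm.fit(X, y)
y_pred = lgbm.predict(X)

# Random search over the "random" parameter space defined above.
search = RandomizedSearchCV(
    LightGBM(random_state=0),
    param_distributions=lgbm.get_grid_params(search_type="random"),
    n_iter=10,
    cv=3,
    random_state=0,
)
search.fit(X, y)
print(search.best_params_)

With search_type="bayes", the skopt spaces would instead plug into scikit-optimize's BayesSearchCV.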
179 changes: 0 additions & 179 deletions autoemulate/emulators/xgboost.py

This file was deleted.
