diff --git a/src/mobo_qm9.py b/src/mobo_qm9.py index 4796f61..359cd3f 100644 --- a/src/mobo_qm9.py +++ b/src/mobo_qm9.py @@ -11,6 +11,7 @@ from botorch.models.transforms.outcome import Standardize from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood import pandas as pd +from sklearn.ensemble import RandomForestRegressor from .data.cm_featurizer import get_coulomb_matrix from.data.soap_featurizer import get_soap @@ -144,6 +145,17 @@ def get_surrogate_model(self, acq): fit_gpytorch_mll(mll) return model + + def surrogate_RF(self,acq): + + features = self.train_indices[acq] + targets = self.correct_sign(self.targets[self.train_indices[acq]]) + model = RandomForestRegressor(max_depth=10, min_samples_split=4, n_estimators=len(self.X_train), random_state=13) + model.fit(features,targets) + #preds = [tree.predict(self.X_test.values) for tree in rf.estimators_] + #mean_pred, sigma = np.mean(preds, axis=0), np.std(preds, axis=0) + #return mean_pred, sigma + return model def correct_sign(self, Y): y_copy = Y.copy()