Skip to content

Commit

Permalink
added RF surrogate
Browse files Browse the repository at this point in the history
  • Loading branch information
akhilsnair2017 committed Mar 28, 2024
1 parent 705f412 commit 5e5cf95
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/mobo_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from botorch.models.transforms.outcome import Standardize
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

from .data.cm_featurizer import get_coulomb_matrix
from.data.soap_featurizer import get_soap
Expand Down Expand Up @@ -144,6 +145,17 @@ def get_surrogate_model(self, acq):
fit_gpytorch_mll(mll)

return model

def surrogate_RF(self,acq):

features = self.train_indices[acq]
targets = self.correct_sign(self.targets[self.train_indices[acq]])
model = RandomForestRegressor(max_depth=10, min_samples_split=4, n_estimators=len(self.X_train), random_state=13)
model.fit(features,targets)
#preds = [tree.predict(self.X_test.values) for tree in rf.estimators_]
#mean_pred, sigma = np.mean(preds, axis=0), np.std(preds, axis=0)
#return mean_pred, sigma
return model

def correct_sign(self, Y):
y_copy = Y.copy()
Expand Down

0 comments on commit 5e5cf95

Please sign in to comment.