-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_selection.py
47 lines (40 loc) · 1.77 KB
/
model_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from sklearn.model_selection import GridSearchCV
import pandas as pd
from os.path import normpath
class GridSearchCV(GridSearchCV):
"""docstring for GridSearchCV"""
def __init__(self, est_class, est_params, param_grid, cv=None, n_jobs=1,
error_score="raise", save_path=None):
self.est_class = est_class
self.est_params = est_params
self.param_grid = param_grid
self.n_jobs = n_jobs
self.estimator = est_class(est_params)
self.set_save_path(save_path)
self.cv = cv
if cv is not None:
self.cv_obj = cv["class"](**cv["params"])
else:
self.cv_obj = None
super(GridSearchCV, self).__init__(self.estimator, param_grid,
cv=self.cv_obj,
n_jobs=n_jobs,
error_score=error_score)
def fit(self, X, y=None, groups=None, **fit_params):
super(GridSearchCV, self).fit(X, y, groups, **fit_params)
data = {
"best_params_": self.best_params_,
"mean_test_score": self.cv_results_["mean_test_score"],
"std_test_score": self.cv_results_["std_test_score"],
}
df = pd.DataFrame.from_dict(pd.io.json.json_normalize(data))
df.to_csv(normpath(self.save_path+"GridSearchCV.csv"))
if hasattr(self.best_estimator_, "save_path"):
self.best_estimator_.set_save_path(self.save_path)
self.best_estimator_.fit(X, y)
return self
def set_save_path(self, save_path):
self.save_path = save_path
if (hasattr(self, "best_estimator_") and
hasattr(self.best_estimator_, "save_path")):
self.best_estimator_.set_save_path(save_path)