-
Notifications
You must be signed in to change notification settings - Fork 0
/
pipeline.py
29 lines (24 loc) · 980 Bytes
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from sklearn.pipeline import Pipeline
class Pipeline(Pipeline):
    """scikit-learn ``Pipeline`` assembled from a list of step specifications.

    Each entry of ``class_list`` is a dict with a ``"class"`` key holding the
    class of the step and an optional ``"params"`` key holding the keyword
    arguments used to instantiate that class. Every step is named after its
    class (``cls.__name__``).

    NOTE: this class reuses the name of the base class it imports, so once
    this definition has run the module-level name ``Pipeline`` refers to this
    subclass, not to the imported base class.
    """

    def __init__(self, class_list, save_path=None):
        """Instantiate the steps and push ``save_path`` down to them.

        Args:
            class_list: Sequence of step specification dicts (see the class
                docstring). It is iterated again by ``set_save_path``, so it
                must be re-iterable (e.g. a list), not a one-shot generator.
            save_path: Value stored as ``self.save_path`` and forwarded to
                the steps selected by ``set_save_path``.
        """
        # Stored as-is (not copied): scikit-learn's estimator conventions
        # expect constructor arguments to be kept unmodified on the instance.
        self.class_list = class_list
        # The base constructor assigns ``self.steps`` itself.
        super(Pipeline, self).__init__(self.load_steps(class_list))
        # NOTE(review): with the default ``save_path=None`` this still
        # overwrites a ``save_path`` supplied through a step's "params" --
        # confirm that the pipeline-level value is meant to win.
        self.set_save_path(save_path)

    def load_steps(self, class_list):
        """Build the ``(name, instance)`` step list from ``class_list``.

        Args:
            class_list: Iterable of dicts with a ``"class"`` key and an
                optional ``"params"`` key (keyword arguments for the class).

        Returns:
            A list of ``(class_name, instance)`` tuples, in input order.

        Raises:
            KeyError: If an entry has no ``"class"`` key.
        """
        steps = []
        for spec in class_list:
            step_cls = spec["class"]
            # A missing, None or empty "params" entry all mean "no arguments".
            kwargs = spec.get("params") or {}
            steps.append((step_cls.__name__, step_cls(**kwargs)))
        return steps

    def set_save_path(self, save_path):
        """Record ``save_path`` and forward it to the steps that opt in.

        A step opts in when its class has a ``set_save_path`` attribute. The
        value is delivered through ``set_params`` as
        ``<ClassName>__save_path``.

        Args:
            save_path: Value to store on the pipeline and on opted-in steps.
        """
        self.save_path = save_path
        # NOTE(review): opting in is detected via a ``set_save_path``
        # attribute, but the value is applied via ``set_params``; this
        # presumably requires the step to accept ``save_path`` as a
        # parameter as well -- confirm against the step classes.
        nested_params = {
            spec["class"].__name__ + "__save_path": save_path
            for spec in self.class_list
            if hasattr(spec["class"], "set_save_path")
        }
        # One batched call instead of one ``set_params`` call per step; no
        # call at all when no step opts in, as before.
        if nested_params:
            self.set_params(**nested_params)