diff --git a/dfedForest.py b/dfedForest.py index f48c7dd..2b1bae4 100644 --- a/dfedForest.py +++ b/dfedForest.py @@ -61,8 +61,8 @@ def loadForest(self): self.forestList.append(loads(readFile.read())) # Generate a new tree based on a dataset - def createTree(self): - self.newTree = tree.DecisionTreeRegressor() + def createTree(self,depth): + self.newTree = tree.DecisionTreeClassifier(max_depth=depth) self.newTree = self.newTree.fit(self.data, self.label) # Show all tree on the list