diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd
index e4a7e15ce16c1..425607342e1a7 100644
--- a/sklearn/tree/_criterion.pxd
+++ b/sklearn/tree/_criterion.pxd
@@ -75,3 +75,4 @@ cdef class RegressionCriterion(Criterion):
     """Abstract regression criterion."""
 
     cdef double sq_sum_total
+    cdef object random_state    # used to draw predictor weights (projection-based criteria)
diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx
index d11f67854731e..e2ea3ac1b3aac 100644
--- a/sklearn/tree/_criterion.pyx
+++ b/sklearn/tree/_criterion.pyx
@@ -26,6 +26,8 @@ import numpy as np
 cimport numpy as np
 np.import_array()
 
+from ._utils cimport rand_int
+from ._utils cimport RAND_R_MAX
 from ._utils cimport log
 from ._utils cimport safe_realloc
 from ._utils cimport sizet_ptr_to_ndarray
@@ -74,7 +76,6 @@ cdef class Criterion:
             The first sample to be used on this node
         end : SIZE_t
             The last sample used on this node
-
         """
 
         pass
@@ -167,7 +168,6 @@ cdef class Criterion:
         cdef double impurity_left
         cdef double impurity_right
         self.children_impurity(&impurity_left, &impurity_right)
-
         return (- self.weighted_n_right * impurity_right
                 - self.weighted_n_left * impurity_left)
 
@@ -689,7 +689,7 @@ cdef class RegressionCriterion(Criterion):
         = (\sum_i^n y_i ** 2) - n_samples * y_bar ** 2
     """
 
-    def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples):
+    def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples, object random_state=None):
         """Initialize parameters for this criterion.
 
         Parameters
@@ -699,11 +699,17 @@ cdef class RegressionCriterion(Criterion):
 
         n_samples : SIZE_t
             The total number of samples to fit on
+
+        random_state : object
+            Random state passed in by the splitter; used to draw the
+            random predictors and weights of the projection-based criteria
+
         """
         # Default values
         self.sample_weight = NULL
 
+        self.random_state = random_state
+
         self.samples = NULL
         self.start = 0
         self.pos = 0
@@ -980,7 +986,7 @@ cdef class MAE(RegressionCriterion):
     cdef np.ndarray right_child
     cdef DOUBLE_t* node_medians
 
-    def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples):
+    def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples, object random_state=None):
         """Initialize parameters for this criterion.
 
         Parameters
@@ -1325,3 +1331,329 @@ cdef class FriedmanMSE(MSE):
 
         return (diff * diff / (self.weighted_n_left * self.weighted_n_right *
                                self.weighted_n_node_samples))
+
+
+cdef class AxisProjection(RegressionCriterion):
+    r"""Mean squared error impurity criterion of axis-aligned projections
+    of high-dimensional y.
+
+    Algorithm:
+    1. select a random predictor k from [0, n_outputs)
+    2. compute the MSE on the values of predictor k for all samples
+
+    MSE = var_left + var_right
+    """
+
+    cdef double node_impurity(self) nogil:
+        """Evaluate the impurity of the current node, i.e. the impurity of
+        samples[start:end]."""
+
+        cdef DOUBLE_t* sample_weight = self.sample_weight
+        cdef SIZE_t* samples = self.samples
+        cdef SIZE_t end = self.end
+        cdef SIZE_t start = self.start
+
+        cdef double* sum_total = self.sum_total
+        cdef DOUBLE_t y_ik
+
+        cdef double impurity
+        cdef double sq_sum_total = 0.0
+
+        cdef SIZE_t i
+        cdef SIZE_t p
+        cdef SIZE_t k
+        cdef DOUBLE_t w = 1.0
+        cdef UINT32_t rand_r_state
+        cdef UINT32_t* random_state = &rand_r_state
+
+        # Seed a local rand_r state; the GIL is needed because
+        # self.random_state is a Python object
+        with gil:
+            rand_r_state = self.random_state.randint(0, RAND_R_MAX)
+
+        # Project onto one randomly chosen predictor
+        k = rand_int(0, self.n_outputs, random_state)
+
+        for p in range(start, end):
+            i = samples[p]
+            if sample_weight != NULL:
+                w = sample_weight[i]
+            y_ik = self.y[i, k]
+            sq_sum_total += w * y_ik * y_ik
+
+        impurity = sq_sum_total / self.weighted_n_node_samples
+        impurity -= (sum_total[k] / self.weighted_n_node_samples) ** 2.0
+
+        return impurity
+
+    cdef double proxy_impurity_improvement(self) nogil:
+        """Compute a proxy of the impurity reduction.
+
+        This method is used to speed up the search for the best split.
+        It is a proxy quantity such that the split that maximizes this value
+        also maximizes the impurity improvement. It neglects all constant
+        terms of the impurity decrease for a given split.
+
+        The absolute impurity improvement is only computed by the
+        impurity_improvement method once the best split has been found.
+        """
+        cdef double* sum_left = self.sum_left
+        cdef double* sum_right = self.sum_right
+
+        cdef SIZE_t k
+        cdef double proxy_impurity_left = 0.0
+        cdef double proxy_impurity_right = 0.0
+
+        cdef UINT32_t rand_r_state
+        cdef UINT32_t* random_state = &rand_r_state
+
+        with gil:
+            rand_r_state = self.random_state.randint(0, RAND_R_MAX)
+
+        # One randomly chosen predictor (drawn independently per call)
+        k = rand_int(0, self.n_outputs, random_state)
+
+        proxy_impurity_left += sum_left[k] * sum_left[k]
+        proxy_impurity_right += sum_right[k] * sum_right[k]
+
+        return (proxy_impurity_left / self.weighted_n_left +
+                proxy_impurity_right / self.weighted_n_right)
+
+    cdef void children_impurity(self, double* impurity_left,
+                                double* impurity_right) nogil:
+        """Evaluate the impurity in children nodes, i.e. the impurity of the
+        left child (samples[start:pos]) and the impurity of the right child
+        (samples[pos:end])."""
+
+        cdef DOUBLE_t* sample_weight = self.sample_weight
+        cdef SIZE_t* samples = self.samples
+        cdef SIZE_t pos = self.pos
+        cdef SIZE_t start = self.start
+        cdef SIZE_t end = self.end
+
+        cdef double* sum_left = self.sum_left
+        cdef double* sum_right = self.sum_right
+        cdef DOUBLE_t y_ik
+
+        cdef double sq_sum_left = 0.0
+        cdef double sq_sum_right = 0.0
+
+        cdef SIZE_t i
+        cdef SIZE_t p
+        cdef SIZE_t k
+        cdef DOUBLE_t w = 1.0
+        cdef UINT32_t rand_r_state
+        cdef UINT32_t* random_state = &rand_r_state
+
+        with gil:
+            rand_r_state = self.random_state.randint(0, RAND_R_MAX)
+
+        k = rand_int(0, self.n_outputs, random_state)
+
+        for p in range(start, pos):
+            i = samples[p]
+
+            if sample_weight != NULL:
+                w = sample_weight[i]
+            y_ik = self.y[i, k]
+            sq_sum_left += w * y_ik * y_ik
+
+        for p in range(pos, end):
+            i = samples[p]
+
+            if sample_weight != NULL:
+                w = sample_weight[i]
+            y_ik = self.y[i, k]
+            sq_sum_right += w * y_ik * y_ik
+
+        impurity_left[0] = sq_sum_left / self.weighted_n_left
+        impurity_right[0] = sq_sum_right / self.weighted_n_right
+
+        impurity_left[0] -= (sum_left[k] / self.weighted_n_left) ** 2.0
+        impurity_right[0] -= (sum_right[k] / self.weighted_n_right) ** 2.0
+
+
+cdef class ObliqueProjection(RegressionCriterion):
+    r"""Mean squared error impurity criterion of oblique projections
+    of high-dimensional y.
+
+    Algorithm:
+    1. Draw a random number of predictors from [0, n_outputs)
+       (with replacement, so duplicates may be drawn)
+    2. Assign a random weight (-1 or 1) to every chosen predictor
+    3. Assign a weight of 0 to all unchosen predictors
+    4. Form a new predictor (linear combination of all predictors)
+    5. Compute the MSE on the new predictor
+
+    MSE = var_left + var_right
+    """
+
+    cdef double node_impurity(self) nogil:
+        """Evaluate the impurity of the current node, i.e. the impurity of
+        samples[start:end]."""
+
+        cdef DOUBLE_t* sample_weight = self.sample_weight
+        cdef SIZE_t* samples = self.samples
+        cdef SIZE_t end = self.end
+        cdef SIZE_t start = self.start
+
+        cdef double* sum_total = self.sum_total
+        cdef DOUBLE_t y_ik
+
+        cdef double impurity
+        cdef double sq_sum_total = 0.0
+
+        cdef SIZE_t i
+        cdef SIZE_t p
+        cdef SIZE_t k
+        cdef SIZE_t num_pred
+        cdef SIZE_t a
+        cdef DOUBLE_t w = 1.0
+        cdef UINT32_t rand_r_state
+        cdef UINT32_t* random_state = &rand_r_state
+        cdef double* pred_weights = <double*> calloc(self.n_outputs,
+                                                     sizeof(double))
+
+        with gil:
+            rand_r_state = self.random_state.randint(0, RAND_R_MAX)
+
+        # Choose how many predictors enter the projection, then give each
+        # chosen predictor a random weight of -1 or +1 (unchosen predictors
+        # keep the weight 0 from calloc)
+        num_pred = rand_int(1, self.n_outputs + 1, random_state)
+        for i in range(num_pred):
+            k = rand_int(0, self.n_outputs, random_state)
+            a = rand_int(0, 2, random_state)
+            if a == 0:
+                a -= 1
+            pred_weights[k] = a
+
+        for p in range(start, end):
+            i = samples[p]
+            if sample_weight != NULL:
+                w = sample_weight[i]
+            for k in range(self.n_outputs):
+                y_ik = self.y[i, k]
+                sq_sum_total += w * y_ik * y_ik * pred_weights[k]
+
+        impurity = sq_sum_total / self.weighted_n_node_samples
+        for k in range(self.n_outputs):
+            impurity -= (sum_total[k] * pred_weights[k] /
+                         self.weighted_n_node_samples) ** 2.0
+
+        impurity = fabs(impurity)
+        free(pred_weights)
+        return impurity / num_pred
+
+    cdef double proxy_impurity_improvement(self) nogil:
+        """Compute a proxy of the impurity reduction.
+
+        This method is used to speed up the search for the best split.
+        It is a proxy quantity such that the split that maximizes this value
+        also maximizes the impurity improvement. It neglects all constant
+        terms of the impurity decrease for a given split.
+
+        The absolute impurity improvement is only computed by the
+        impurity_improvement method once the best split has been found.
+        """
+        cdef double* sum_left = self.sum_left
+        cdef double* sum_right = self.sum_right
+
+        cdef SIZE_t i
+        cdef SIZE_t k
+        cdef double proxy_impurity_left = 0.0
+        cdef double proxy_impurity_right = 0.0
+
+        cdef SIZE_t num_pred
+        cdef SIZE_t a
+        cdef UINT32_t rand_r_state
+        cdef UINT32_t* random_state = &rand_r_state
+        cdef double* pred_weights = <double*> calloc(self.n_outputs,
+                                                     sizeof(double))
+
+        with gil:
+            rand_r_state = self.random_state.randint(0, RAND_R_MAX)
+
+        num_pred = rand_int(1, self.n_outputs + 1, random_state)
+        for i in range(num_pred):
+            k = rand_int(0, self.n_outputs, random_state)
+            a = rand_int(0, 2, random_state)
+            if a == 0:
+                a -= 1
+            pred_weights[k] = a    # note: the weights are not normalized
+
+        for k in range(self.n_outputs):
+            proxy_impurity_left += sum_left[k] * sum_left[k] * pred_weights[k]
+            proxy_impurity_right += (sum_right[k] * sum_right[k] *
+                                     pred_weights[k])
+
+        proxy_impurity_left = fabs(proxy_impurity_left)
+        proxy_impurity_right = fabs(proxy_impurity_right)
+        free(pred_weights)
+        return (proxy_impurity_left / self.weighted_n_left +
+                proxy_impurity_right / self.weighted_n_right)
+
+    cdef void children_impurity(self, double* impurity_left,
+                                double* impurity_right) nogil:
+        """Evaluate the impurity in children nodes, i.e. the impurity of the
+        left child (samples[start:pos]) and the impurity of the right child
+        (samples[pos:end])."""
+
+        cdef DOUBLE_t* sample_weight = self.sample_weight
+        cdef SIZE_t* samples = self.samples
+        cdef SIZE_t pos = self.pos
+        cdef SIZE_t start = self.start
+        cdef SIZE_t end = self.end
+
+        cdef double* sum_left = self.sum_left
+        cdef double* sum_right = self.sum_right
+        cdef DOUBLE_t y_ik
+
+        cdef double sq_sum_left = 0.0
+        cdef double sq_sum_right = 0.0
+
+        cdef SIZE_t i
+        cdef SIZE_t p
+        cdef SIZE_t k
+        cdef SIZE_t num_pred
+        cdef SIZE_t a
+        cdef DOUBLE_t w = 1.0
+        cdef UINT32_t rand_r_state
+        cdef UINT32_t* random_state = &rand_r_state
+        cdef double* pred_weights = <double*> calloc(self.n_outputs,
+                                                     sizeof(double))
+
+        with gil:
+            rand_r_state = self.random_state.randint(0, RAND_R_MAX)
+
+        num_pred = rand_int(1, self.n_outputs + 1, random_state)
+        for i in range(num_pred):
+            k = rand_int(0, self.n_outputs, random_state)
+            a = rand_int(0, 2, random_state)
+            if a == 0:
+                a -= 1
+            pred_weights[k] = a
+
+        for p in range(start, pos):
+            i = samples[p]
+
+            if sample_weight != NULL:
+                w = sample_weight[i]
+            for k in range(self.n_outputs):
+                y_ik = self.y[i, k]
+                sq_sum_left += w * y_ik * y_ik * pred_weights[k]
+
+        for p in range(pos, end):
+            i = samples[p]
+
+            if sample_weight != NULL:
+                w = sample_weight[i]
+            for k in range(self.n_outputs):
+                y_ik = self.y[i, k]
+                sq_sum_right += w * y_ik * y_ik * pred_weights[k]
+
+        impurity_left[0] = sq_sum_left / self.weighted_n_left
+        impurity_right[0] = sq_sum_right / self.weighted_n_right
+
+        for k in range(self.n_outputs):
+            impurity_left[0] -= (pred_weights[k] *
+                                 (sum_left[k] / self.weighted_n_left) ** 2.0)
+            impurity_right[0] -= (pred_weights[k] *
+                                  (sum_right[k] / self.weighted_n_right) ** 2.0)
+
+        impurity_left[0] = fabs(impurity_left[0])
+        impurity_right[0] = fabs(impurity_right[0])
+        free(pred_weights)
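For orientation before the test changes, here is a minimal NumPy sketch (commentary, not part of the diff) of the node impurity each new criterion computes. `y` stands for the `(n_samples, n_outputs)` target block at a node and `w` for the sample weights; the Cython code above draws its integers through `rand_int` seeded from the splitter's `random_state`, so the `RandomState` calls below are stand-ins.

```python
import numpy as np

def axis_impurity(y, w, rng):
    # AxisProjection: weighted variance (MSE) of one randomly
    # chosen output column
    k = rng.randint(y.shape[1])
    mean_k = np.average(y[:, k], weights=w)
    return np.average(y[:, k] ** 2, weights=w) - mean_k ** 2

def oblique_impurity(y, w, rng):
    # ObliqueProjection: draw 1..n_outputs predictors (with replacement)
    # and give each a weight of -1 or +1; unchosen predictors stay 0
    n_outputs = y.shape[1]
    num_pred = rng.randint(1, n_outputs + 1)
    pred_weights = np.zeros(n_outputs)
    for _ in range(num_pred):
        pred_weights[rng.randint(n_outputs)] = rng.choice([-1.0, 1.0])
    # Per-output variance terms scaled by the +/-1 weights, mirroring
    # the loops in the Cython methods above
    sq_term = np.average((y ** 2) @ pred_weights, weights=w)
    mean_term = np.sum((np.average(y, axis=0, weights=w)
                        * pred_weights) ** 2)
    return abs(sq_term - mean_term) / num_pred

rng = np.random.RandomState(0)
y = rng.rand(10, 3)
w = np.ones(10)
print(axis_impurity(y, w, rng), oblique_impurity(y, w, rng))
```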
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index 193b459b93b38..d53f29585e177 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -45,7 +45,7 @@
 from sklearn.utils import compute_sample_weight
 
 CLF_CRITERIONS = ("gini", "entropy")
-REG_CRITERIONS = ("mse", "mae", "friedman_mse")
+REG_CRITERIONS = ("mse", "mae", "friedman_mse", "axis", "oblique")
 
 CLF_TREES = {
     "DecisionTreeClassifier": DecisionTreeClassifier,
@@ -260,11 +260,12 @@ def test_iris():
                      "Failed with {0}, criterion = {1} and score = {2}"
                      "".format(name, criterion, score))
 
+# "oblique" is excluded here: its randomly signed projection weights give
+# it no fixed MSE baseline (see test_oblique_proj for its dedicated checks)
+REG_CRITERIONS_ = ("mse", "mae", "friedman_mse", "axis")
+
 
 def test_boston():
     # Check consistency on dataset boston house prices.
-    for (name, Tree), criterion in product(REG_TREES.items(), REG_CRITERIONS):
+    for (name, Tree), criterion in product(REG_TREES.items(), REG_CRITERIONS_):
         reg = Tree(criterion=criterion, random_state=0)
         reg.fit(boston.data, boston.target)
         score = mean_squared_error(boston.target, reg.predict(boston.data))
@@ -281,7 +282,6 @@ def test_boston():
                      "Failed with {0}, criterion = {1} and score = {2}"
                      "".format(name, criterion, score))
 
-
 def test_probability():
     # Predict probabilities using DecisionTreeClassifier.
@@ -1778,6 +1778,112 @@ def test_mae():
     assert_array_equal(dt_mae.tree_.impurity, [1.4, 1.5, 4.0 / 3.0])
     assert_array_equal(dt_mae.tree_.value.flat, [4, 4.5, 4.0])
 
+
+def test_axis_proj():
+    """Check that the axis projection criterion produces correct results
+    on a small toy dataset:
+
+    -----------------------
+    | X | y1  y2 | weight |
+    -----------------------
+    | 3 |  3   3 |    0.1 |
+    | 5 |  3   3 |    0.3 |
+    | 8 |  4   4 |    1.0 |
+    | 3 |  7   7 |    0.6 |
+    | 5 |  8   8 |    0.3 |
+    -----------------------
+    """
+    dt_axis = DecisionTreeRegressor(random_state=0, criterion="axis",
+                                    max_leaf_nodes=2)
+    dt_mse = DecisionTreeRegressor(random_state=0, criterion="mse",
+                                   max_leaf_nodes=2)
+
+    # Test axis projection where sample weights are non-uniform
+    # (as illustrated above):
+    dt_axis.fit(X=[[3], [5], [8], [3], [5]],
+                y=[[3, 3], [3, 3], [4, 4], [7, 7], [8, 8]],
+                sample_weight=[0.1, 0.3, 1.0, 0.6, 0.3])
+    dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8],
+               sample_weight=[0.1, 0.3, 1.0, 0.6, 0.3])
+    assert abs(dt_mse.tree_.impurity[0] - dt_axis.tree_.impurity[0]) < 0.01
+    assert abs(dt_mse.tree_.impurity[1] - dt_axis.tree_.impurity[1]) < 0.01
+    assert abs(dt_axis.tree_.impurity[2]) < 0.01
+
+    # Test axis projection where all sample weights are uniform:
+    dt_axis.fit(X=[[3], [5], [8], [3], [5]],
+                y=[[3, 3], [3, 3], [4, 4], [7, 7], [8, 8]],
+                sample_weight=np.ones(5))
+    dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8],
+               sample_weight=np.ones(5))
+    assert abs(dt_mse.tree_.impurity[0] - dt_axis.tree_.impurity[0]) < 0.01
+    assert abs(dt_mse.tree_.impurity[1] - dt_axis.tree_.impurity[1]) < 0.01
+    assert abs(dt_axis.tree_.impurity[2]) < 0.01
+
+    # Test axis projection where a `sample_weight` is not explicitly
+    # provided. This is equivalent to providing uniform sample weights,
+    # though the internal logic is different:
+    dt_axis.fit(X=[[3], [5], [8], [3], [5]],
+                y=[[3, 3], [3, 3], [4, 4], [7, 7], [8, 8]])
+    dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8])
+    assert abs(dt_mse.tree_.impurity[0] - dt_axis.tree_.impurity[0]) < 0.01
+    assert abs(dt_mse.tree_.impurity[1] - dt_axis.tree_.impurity[1]) < 0.01
+    assert abs(dt_axis.tree_.impurity[2]) < 0.01
+
+
+def test_oblique_proj():
+    """Check that the oblique projection criterion produces correct results
+    on a small toy dataset:
+
+    -----------------------
+    | X | y1  y2 | weight |
+    -----------------------
+    | 3 |  3   3 |    0.1 |
+    | 5 |  3   3 |    0.3 |
+    | 8 |  4   4 |    1.0 |
+    | 3 |  7   7 |    0.6 |
+    | 5 |  8   8 |    0.3 |
+    -----------------------
+    """
+    dt_oblique = DecisionTreeRegressor(random_state=3, criterion="oblique",
+                                       max_leaf_nodes=2)
+    dt_mse = DecisionTreeRegressor(random_state=3, criterion="mse",
+                                   max_leaf_nodes=2)
+
+    # Test oblique projection where sample weights are non-uniform
+    # (as illustrated above):
+    dt_oblique.fit(X=[[3], [5], [8], [3], [5]],
+                   y=[[3, 3], [3, 3], [4, 4], [7, 7], [8, 8]],
+                   sample_weight=[0.1, 0.3, 1.0, 0.6, 0.3])
+    dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8],
+               sample_weight=[0.1, 0.3, 1.0, 0.6, 0.3])
+
+    # Depending on the randomly drawn projection weights, the oblique
+    # impurity can match the single-output MSE (one output selected),
+    # twice it (both outputs, same sign), or collapse to ~0 (signs cancel):
+    assert (abs(dt_mse.tree_.impurity[0] - dt_oblique.tree_.impurity[0]) < 0.01
+            or abs(2.0 * dt_mse.tree_.impurity[0] -
+                   dt_oblique.tree_.impurity[0]) < 0.01
+            or abs(dt_oblique.tree_.impurity[0]) < 0.01)
+    assert (abs(dt_mse.tree_.impurity[1] - dt_oblique.tree_.impurity[1]) < 0.01
+            or abs(2.0 * dt_mse.tree_.impurity[1] -
+                   dt_oblique.tree_.impurity[1]) < 0.01
+            or abs(dt_oblique.tree_.impurity[1]) < 0.01)
+    assert abs(dt_oblique.tree_.impurity[2]) < 0.01
+
+    # Test oblique projection where all sample weights are uniform:
+    dt_oblique.fit(X=[[3], [5], [8], [3], [5]],
+                   y=[[3, 3], [3, 3], [4, 4], [7, 7], [8, 8]],
+                   sample_weight=np.ones(5))
+    dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8],
+               sample_weight=np.ones(5))
+    assert (abs(dt_mse.tree_.impurity[0] - dt_oblique.tree_.impurity[0]) < 0.01
+            or abs(2.0 * dt_mse.tree_.impurity[0] -
+                   dt_oblique.tree_.impurity[0]) < 0.01
+            or abs(dt_oblique.tree_.impurity[0]) < 0.01)
+    assert (abs(dt_mse.tree_.impurity[1] - dt_oblique.tree_.impurity[1]) < 0.01
+            or abs(2.0 * dt_mse.tree_.impurity[1] -
+                   dt_oblique.tree_.impurity[1]) < 0.01
+            or abs(dt_oblique.tree_.impurity[1]) < 0.01)
+    assert abs(dt_oblique.tree_.impurity[2]) < 0.01
+
+    # Test oblique projection where a `sample_weight` is not explicitly
+    # provided. This is equivalent to providing uniform sample weights,
+    # though the internal logic is different:
+    dt_oblique.fit(X=[[3], [5], [8], [3], [5]],
+                   y=[[3, 3], [3, 3], [4, 4], [7, 7], [8, 8]])
+    dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8])
+    assert (abs(dt_mse.tree_.impurity[0] - dt_oblique.tree_.impurity[0]) < 0.01
+            or abs(2.0 * dt_mse.tree_.impurity[0] -
+                   dt_oblique.tree_.impurity[0]) < 0.01
+            or abs(dt_oblique.tree_.impurity[0]) < 0.01)
+    assert (abs(dt_mse.tree_.impurity[1] - dt_oblique.tree_.impurity[1]) < 0.01
+            or abs(2.0 * dt_mse.tree_.impurity[1] -
+                   dt_oblique.tree_.impurity[1]) < 0.01
+            or abs(dt_oblique.tree_.impurity[1]) < 0.01)
+    assert abs(dt_oblique.tree_.impurity[2]) < 0.01
+
 
 def test_criterion_copy():
     # Let's check whether copy of our criterion has the same type
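The 0.01 tolerances in the tests above compare against the ordinary weighted MSE of the toy targets. A quick standalone check (not part of the test suite) of the root impurity the `mse` tree stores:

```python
import numpy as np

y = np.array([3.0, 3.0, 4.0, 7.0, 8.0])
w = np.array([0.1, 0.3, 1.0, 0.6, 0.3])

# Weighted MSE at the root node: E_w[y^2] - (E_w[y])^2
mean = np.average(y, weights=w)
root_impurity = np.average(y ** 2, weights=w) - mean ** 2
print(root_impurity)  # ~3.33; this is what dt_mse.tree_.impurity[0] holds
```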
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index d88bc5830359b..522252fef0536 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -60,7 +60,8 @@
 CRITERIA_CLF = {"gini": _criterion.Gini, "entropy": _criterion.Entropy}
 CRITERIA_REG = {"mse": _criterion.MSE, "friedman_mse": _criterion.FriedmanMSE,
-                "mae": _criterion.MAE}
+                "mae": _criterion.MAE, "oblique": _criterion.ObliqueProjection,
+                "axis": _criterion.AxisProjection}
 
 DENSE_SPLITTERS = {"best": _splitter.BestSplitter,
                    "random": _splitter.RandomSplitter}
@@ -325,7 +326,8 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                                                          self.n_classes_)
             else:
                 criterion = CRITERIA_REG[self.criterion](self.n_outputs_,
-                                                         n_samples)
+                                                         n_samples,
+                                                         random_state)
 
         SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
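Assuming this branch is compiled, the new criteria plug in through the same `criterion` string as the built-in ones, and `random_state` now also seeds the criterion's predictor sampling via the `tree.py` change above. A minimal smoke test on a synthetic multi-output problem (all names and data below are illustrative):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(42)
X = rng.rand(200, 4)
# Two correlated outputs derived from the first feature
Y = np.column_stack([X[:, 0] + 0.1 * rng.randn(200),
                     X[:, 0] - 0.1 * rng.randn(200)])

for criterion in ("axis", "oblique"):
    tree = DecisionTreeRegressor(criterion=criterion, max_depth=3,
                                 random_state=0).fit(X, Y)
    print(criterion, tree.tree_.node_count)
```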