From 0c9335da006adeb8c5d3ce9d61b4eb47fa4ae5fa Mon Sep 17 00:00:00 2001
From: Morgan Sanchez
Date: Thu, 19 Dec 2019 03:02:31 -0500
Subject: [PATCH] make criterion more memory efficient and adjust tests
 accordingly

---
 sklearn/tree/_criterion.pyx     | 208 +++++++++++++++-----------------
 sklearn/tree/tests/test_tree.py | 205 +++++++-------------------------
 2 files changed, 141 insertions(+), 272 deletions(-)

diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx
index 874083b8d6112..e2ea3ac1b3aac 100644
--- a/sklearn/tree/_criterion.pyx
+++ b/sklearn/tree/_criterion.pyx
@@ -168,7 +168,6 @@ cdef class Criterion:
         cdef double impurity_left
         cdef double impurity_right
         self.children_impurity(&impurity_left, &impurity_right)
-
         return (- self.weighted_n_right * impurity_right
                 - self.weighted_n_left * impurity_left)
 
@@ -1333,7 +1332,6 @@ cdef class FriedmanMSE(MSE):
         return (diff * diff / (self.weighted_n_left * self.weighted_n_right *
                                self.weighted_n_node_samples))
 
-
 cdef class AxisProjection(RegressionCriterion):
     r"""Mean squared error impurity criterion of axis-aligned projections of
     high dimensional y
@@ -1344,24 +1342,28 @@ cdef class AxisProjection(RegressionCriterion):
 
         MSE = var_left + var_right
     """
+
     cdef double node_impurity(self) nogil:
         """Evaluate the impurity of the current node, i.e.
            the impurity of samples[start:end]."""
-        cdef double impurity = 0.0
+
+        cdef double impurity
         cdef DOUBLE_t* sample_weight = self.sample_weight
         cdef SIZE_t* samples = self.samples
         cdef SIZE_t end = self.end
         cdef SIZE_t start = self.start
-        cdef double mean_pred = 0.0
+
+        cdef double* sum_total = self.sum_total
         cdef DOUBLE_t y_ik
+        cdef double sq_sum_total = 0.0
+
         cdef SIZE_t i
         cdef SIZE_t p
-        cdef SIZE_t k
+        cdef SIZE_t k
         cdef UINT32_t rand_r_state
-        with gil:
+        with gil:
             rand_r_state = self.random_state.randint(0, RAND_R_MAX)
         cdef UINT32_t* random_state = &rand_r_state
@@ -1374,26 +1376,20 @@ cdef class AxisProjection(RegressionCriterion):
             if sample_weight != NULL:
                 w = sample_weight[i]
             y_ik = self.y[i, k]
-            mean_pred += y_ik / (end - start)
-
-        for p in range(start, end):
-            i = samples[p]
-            if sample_weight != NULL:
-                w = sample_weight[i]
-            impurity += (mean_pred - self.y[i, k]) * (mean_pred - self.y[i, k]) * w
-        impurity /= self.weighted_n_node_samples
+            sq_sum_total += w * y_ik * y_ik
+
+        impurity = sq_sum_total / self.weighted_n_node_samples
+        impurity -= (sum_total[k] / self.weighted_n_node_samples)**2.0
         return impurity
 
     cdef double proxy_impurity_improvement(self) nogil:
         """Compute a proxy of the impurity reduction
-
        This method is used to speed up the search for the best split.
         It is a proxy quantity such that the split that maximizes this value
         also maximizes the impurity improvement. It neglects all constant terms
         of the impurity decrease for a given split.
-
        The absolute impurity improvement is only computed by the
         impurity_improvement method once the best split has been found.
""" @@ -1407,15 +1403,16 @@ cdef class AxisProjection(RegressionCriterion): cdef UINT32_t rand_r_state - with gil: + with gil: rand_r_state = self.random_state.randint(0, RAND_R_MAX) cdef UINT32_t* random_state = &rand_r_state - k = rand_int(0, self.n_outputs, random_state) + k = rand_int(0, self.n_outputs, random_state) proxy_impurity_left += sum_left[k] * sum_left[k] proxy_impurity_right += sum_right[k] * sum_right[k] - + + return (proxy_impurity_left / self.weighted_n_left + proxy_impurity_right / self.weighted_n_right) @@ -1424,62 +1421,57 @@ cdef class AxisProjection(RegressionCriterion): """Evaluate the impurity in children nodes, i.e. the impurity of the left child (samples[start:pos]) and the impurity the right child (samples[pos:end]).""" - + cdef DOUBLE_t* sample_weight = self.sample_weight cdef SIZE_t* samples = self.samples cdef SIZE_t pos = self.pos cdef SIZE_t start = self.start cdef SIZE_t end = self.end + cdef double* sum_left = self.sum_left + cdef double* sum_right = self.sum_right cdef DOUBLE_t y_ik - impurity_left[0] = 0.0 - impurity_right[0] = 0.0 - cdef double mean_pred_left = 0.0 - cdef double mean_pred_right = 0.0 + cdef double sq_sum_left = 0.0 + cdef double sq_sum_right = 0.0 cdef SIZE_t i cdef SIZE_t p cdef SIZE_t k + cdef DOUBLE_t w = 1.0 cdef UINT32_t rand_r_state - with gil: + with gil: rand_r_state = self.random_state.randint(0, RAND_R_MAX) cdef UINT32_t* random_state = &rand_r_state - k = rand_int(0, self.n_outputs, random_state) + k = rand_int(0, self.n_outputs, random_state) - cdef DOUBLE_t w = 1.0 for p in range(start, pos): i = samples[p] - if sample_weight != NULL: - w = sample_weight[i] - y_ik = self.y[i, k] - mean_pred_left += y_ik / (pos - start) - for p in range(start, pos): - i = samples[p] if sample_weight != NULL: w = sample_weight[i] - impurity_left[0] += ((mean_pred_left - self.y[i, k]) - * (mean_pred_left - self.y[i, k]) * w)/self.weighted_n_left + y_ik = self.y[i, k] + sq_sum_left += w * y_ik * y_ik for p in range(pos, end): i = samples[p] + if sample_weight != NULL: w = sample_weight[i] y_ik = self.y[i, k] - mean_pred_right += y_ik / (end - pos) + sq_sum_right += w * y_ik * y_ik - for p in range(pos, end): - i = samples[p] - if sample_weight != NULL: - w = sample_weight[i] - impurity_right[0] += ((mean_pred_right - self.y[i, k]) - * (mean_pred_right - self.y[i, k]) * w)/self.weighted_n_right + impurity_left[0] = sq_sum_left / self.weighted_n_left + impurity_right[0] = sq_sum_right / self.weighted_n_right + + impurity_left[0] -= (sum_left[k] / self.weighted_n_left) ** 2.0 + impurity_right[0] -= (sum_right[k] / self.weighted_n_right) ** 2.0 impurity_left[0] impurity_right[0] + cdef class ObliqueProjection(RegressionCriterion): r"""Mean squared error impurity criterion @@ -1497,24 +1489,26 @@ cdef class ObliqueProjection(RegressionCriterion): cdef double node_impurity(self) nogil: """Evaluate the impurity of the current node, i.e. 
the impurity of samples[start:end].""" - cdef double impurity = 0.0 + + cdef double impurity cdef DOUBLE_t* sample_weight = self.sample_weight cdef SIZE_t* samples = self.samples cdef SIZE_t end = self.end cdef SIZE_t start = self.start - cdef double* pred = calloc(end-start, sizeof(double)) - cdef double mean_pred = 0.0 + cdef double* sum_total = self.sum_total cdef DOUBLE_t y_ik + cdef double sq_sum_total = 0.0 + cdef SIZE_t i cdef SIZE_t p - cdef SIZE_t k + cdef SIZE_t k cdef UINT32_t rand_r_state - cdef SIZE_t num_pred + cdef SIZE_t num_pred cdef SIZE_t a pred_weights = calloc(self.n_outputs, sizeof(double)) - + with gil: rand_r_state = self.random_state.randint(0, RAND_R_MAX) cdef UINT32_t* random_state = &rand_r_state @@ -1526,42 +1520,34 @@ cdef class ObliqueProjection(RegressionCriterion): a = rand_int(0, 2, random_state) if a == 0: a -= 1 - pred_weights[k] = a # didn't normalize + pred_weights[k] = a cdef DOUBLE_t w = 1.0 + for p in range(start, end): i = samples[p] if sample_weight != NULL: w = sample_weight[i] for k in range(self.n_outputs): y_ik = self.y[i, k] - # sum over all predictors with pred weights - pred[p] += y_ik * pred_weights[k] + sq_sum_total += w * y_ik * y_ik * pred_weights[k] - for p in range(start, end): - # sum over all samples to get mean of new predictor - with gil: mean_pred += pred[p] / (end - start) - - for p in range(start, end): - i = samples[p] - if sample_weight != NULL: - w = sample_weight[i] - with gil: impurity += (mean_pred - pred[p]) * (mean_pred - pred[p]) * w - impurity /= self.weighted_n_node_samples + impurity = sq_sum_total / self.weighted_n_node_samples + for k in range(self.n_outputs): + impurity -= (sum_total[k]* pred_weights[k]/ self.weighted_n_node_samples)**2.0 + with gil: impurity = fabs(impurity) free(pred_weights) - free(pred) - return impurity + return impurity / num_pred + cdef double proxy_impurity_improvement(self) nogil: """Compute a proxy of the impurity reduction - This method is used to speed up the search for the best split. It is a proxy quantity such that the split that maximizes this value also maximizes the impurity improvement. It neglects all constant terms of the impurity decrease for a given split. - The absolute impurity improvement is only computed by the impurity_improvement method once the best split has been found. 
""" @@ -1573,10 +1559,31 @@ cdef class ObliqueProjection(RegressionCriterion): cdef double proxy_impurity_left = 0.0 cdef double proxy_impurity_right = 0.0 + cdef UINT32_t rand_r_state + cdef SIZE_t num_pred + cdef SIZE_t a + pred_weights = calloc(self.n_outputs, sizeof(double)) + + with gil: + rand_r_state = self.random_state.randint(0, RAND_R_MAX) + cdef UINT32_t* random_state = &rand_r_state + + num_pred = rand_int(1, self.n_outputs + 1, random_state) + + for i in range(num_pred): + k = rand_int(0, self.n_outputs, random_state) + a = rand_int(0, 2, random_state) + if a == 0: + a -= 1 + pred_weights[k] = a # didn't normalize + for k in range(self.n_outputs): - proxy_impurity_left += sum_left[k] * sum_left[k] - proxy_impurity_right += sum_right[k] * sum_right[k] + proxy_impurity_left += sum_left[k] * sum_left[k] * pred_weights[k] + proxy_impurity_right += sum_right[k] * sum_right[k] * pred_weights[k] + proxy_impurity_left = fabs(proxy_impurity_left) + proxy_impurity_right = fabs(proxy_impurity_right) + free(pred_weights) return (proxy_impurity_left / self.weighted_n_left + proxy_impurity_right / self.weighted_n_right) @@ -1585,87 +1592,68 @@ cdef class ObliqueProjection(RegressionCriterion): """Evaluate the impurity in children nodes, i.e. the impurity of the left child (samples[start:pos]) and the impurity the right child (samples[pos:end]).""" - + cdef DOUBLE_t* sample_weight = self.sample_weight cdef SIZE_t* samples = self.samples cdef SIZE_t pos = self.pos cdef SIZE_t start = self.start cdef SIZE_t end = self.end + cdef double* sum_left = self.sum_left + cdef double* sum_right = self.sum_right cdef DOUBLE_t y_ik - impurity_left[0] = 0.0 - impurity_right[0] = 0.0 - cdef double* pred_left = calloc(pos-start, sizeof(double)) - cdef double* pred_right = calloc(end-pos, sizeof(double)) - cdef double mean_pred_left = 0.0 - cdef double mean_pred_right = 0.0 + cdef double sq_sum_left = 0.0 + cdef double sq_sum_right = 0.0 cdef SIZE_t i cdef SIZE_t p - cdef SIZE_t k + cdef SIZE_t k cdef UINT32_t rand_r_state - cdef SIZE_t num_pred - cdef SIZE_t a + cdef SIZE_t num_pred + cdef SIZE_t a pred_weights = calloc(self.n_outputs, sizeof(double)) - with gil: + with gil: rand_r_state = self.random_state.randint(0, RAND_R_MAX) cdef UINT32_t* random_state = &rand_r_state - num_pred = rand_int(0, self.n_outputs, random_state) + num_pred = rand_int(1, self.n_outputs + 1, random_state) for i in range(num_pred): k = rand_int(0, self.n_outputs, random_state) a = rand_int(0, 2, random_state) if a == 0: a -= 1 - pred_weights[k] = a # didn't normalize + pred_weights[k] = a cdef DOUBLE_t w = 1.0 for p in range(start, pos): i = samples[p] + if sample_weight != NULL: w = sample_weight[i] for k in range(self.n_outputs): y_ik = self.y[i, k] - # sum over all predictors with pred weights - pred_left[p] += y_ik * pred_weights[k] - - for p in range(start, pos): - # sum over all samples to get mean of new predictor - mean_pred_left += pred_left[p] / (pos - start) - - for p in range(start, pos): - i = samples[p] - if sample_weight != NULL: - w = sample_weight[i] - impurity_left[0] += ((mean_pred_left - pred_left[p]) - * (mean_pred_left - pred_left[p]) * w)/self.weighted_n_left - + sq_sum_left += w * y_ik * y_ik * pred_weights[k] + for p in range(pos, end): i = samples[p] + if sample_weight != NULL: w = sample_weight[i] for k in range(self.n_outputs): y_ik = self.y[i, k] - # sum over all predictors with pred weights - pred_right[p - pos] += y_ik * pred_weights[k] - # sum over all samples to get mean of new predictor + 
sq_sum_right += w * y_ik * y_ik * pred_weights[k] - for p in range(pos, end): - mean_pred_right += pred_right[p-pos] / (end - pos) - - for p in range(pos, end): - i = samples[p] - if sample_weight != NULL: - w = sample_weight[i] - impurity_right[0] += ((mean_pred_right - pred_right[p - pos]) - * (mean_pred_right - pred_right[p-pos]) * w) / self.weighted_n_right + impurity_left[0] = sq_sum_left / self.weighted_n_left + impurity_right[0] = sq_sum_right / self.weighted_n_right - impurity_left[0] - impurity_right[0] + for k in range(self.n_outputs): + impurity_left[0] -= pred_weights[k] * (sum_left[k]/ self.weighted_n_left) ** 2.0 + impurity_right[0] -= pred_weights[k] * (sum_right[k]/ self.weighted_n_right) ** 2.0 + impurity_left[0] = fabs(impurity_left[0]) + impurity_right[0] = fabs(impurity_right[0]) free(pred_weights) - free(pred_left) - free(pred_right) \ No newline at end of file + \ No newline at end of file diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index b2bec2ec42b90..d53f29585e177 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1779,7 +1779,8 @@ def test_mae(): assert_array_equal(dt_mae.tree_.value.flat, [4, 4.5, 4.0]) def test_axis_proj(): - """Check axis projection criterion produces correct results on small toy dataset: + """Check axis projection criterion produces correct results on + small toy dataset: ------------------ | X | y1 y2 | weight | @@ -1790,204 +1791,83 @@ def test_axis_proj(): | 3 | 7 7 | 0.6 | | 5 | 8 8 | 0.3 | ------------------ - |sum wt:| 2.3 | - ------------------ - - Mean1 = 5 - Mean2 = 5 - - For all the samples, we can get the total error by summing: - (Mean1 - y1)^2 * weight or (Mean2 - y2)^2 * weight - - I.e., total error = (5 - 3)^2 * 0.1) - + (5 - 3)^2 * 0.3) - + (5 - 4)^2 * 1.0) - + (5 - 7)^2 * 0.6) - + (5 - 8)^2 * 0.3) - = 0.4 + 1.2 + 1.0 + 2.4 + 2.7 - = 7.7 - - Impurity = Total error / total weight - = 7.7 / 2.3 - = 3.3478260869565 - ----------------- - - From this root node, the next best split is between X values of 5 and 8. 
- Thus, we have left and right child nodes: - - LEFT RIGHT - ----------------------- ----------------------- - | X | y1 y2 | weight | | X | y1 y2 | weight | - ----------------------- ----------------------- - | 3 | 3 3 | 0.1 | | 8 | 4 4 | 1.0 | - | 3 | 7 7 | 0.6 | ----------------------- - | 5 | 3 3 | 0.3 | |sum wt:| 1.0 | - | 5 | 8 8 | 0.3 | ----------------------- - ----------------------- - |sum wt:| 1.3 | - ----------------------- - - 5.0625 + 3.0625 + 5.0625 + 7.5625 / 4 + 0 = 5.1875 - 4 + 4.667 = 8.667 - - Impurity is found in the same way: - Left node Mean1 = Mean2 = 5.25 - Total error = ((5.25 - 3)^2 * 0.1) - + ((5.25 - 7)^2 * 0.6) - + ((5.25 - 3)^2 * 0.3) - + ((5.25 - 8)^2 * 0.3) - = 6.13125 - - Left Impurity = Total error / total weight - = 6.13125 / 1.3 - = 4.716346153846154 - ------------------- - - Likewise for Right node: - Right node Mean1 = Mean2 = 4 - Total error = ((4 - 4)^2 * 1.0) - = 0 - - Right Impurity = Total error / total weight - = 0 / 1.0 - = 0.0 - ------ """ dt_axis = DecisionTreeRegressor(random_state=0, criterion="axis", max_leaf_nodes=2) + dt_mse = DecisionTreeRegressor(random_state=0, criterion="mse", + max_leaf_nodes=2) + # Test axis projection where sample weights are non-uniform (as illustrated above): dt_axis.fit(X=[[3], [5], [8], [3], [5]], y=[[3], [3], [4], [7], [8]], sample_weight=[0.1, 0.3, 1.0, 0.6, 0.3]) - assert(abs(7.7 / 2.3 - dt_axis.tree_.impurity[0]) < 0.01) - assert(abs(6.13125 / 1.3 - dt_axis.tree_.impurity[1]) < 0.01) + dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8], + sample_weight=[0.1, 0.3, 1.0, 0.6, 0.3]) + assert(abs(dt_mse.tree_.impurity[0] - dt_axis.tree_.impurity[0]) < 0.01) + assert(abs(dt_mse.tree_.impurity[1] - dt_axis.tree_.impurity[1]) < 0.01) assert(abs(dt_axis.tree_.impurity[2]) < 0.01) # Test axis projection where all sample weights are uniform: dt_axis.fit(X=[[3], [5], [8], [3], [5]], y=[[3,3], [3,3], [4,4], [7,7], [8,8]], sample_weight=np.ones(5)) - assert(abs(22.0 / 5.0 - dt_axis.tree_.impurity[0]) < 0.01) - assert(abs(20.75 / 4.0 - dt_axis.tree_.impurity[1]) < 0.01) + dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8], + sample_weight=np.ones(5)) + assert(abs(dt_mse.tree_.impurity[0] - dt_axis.tree_.impurity[0]) < 0.01) + assert(abs(dt_mse.tree_.impurity[1] - dt_axis.tree_.impurity[1]) < 0.01) assert(abs(dt_axis.tree_.impurity[2]) < 0.01) # Test axis projections where a `sample_weight` is not explicitly provided. 
# This is equivalent to providing uniform sample weights, though # the internal logic is different: dt_axis.fit(X=[[3], [5], [8], [3], [5]], y=[[3,3], [3,3], [4,4], [7,7], [8,8]]) - assert(abs(22.0 / 5.0 - dt_axis.tree_.impurity[0]) < 0.01) - assert(abs(20.75 / 4.0 - dt_axis.tree_.impurity[1]) < 0.01) + dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8]) + assert(abs(dt_mse.tree_.impurity[0] - dt_axis.tree_.impurity[0]) < 0.01) + assert(abs(dt_mse.tree_.impurity[1] - dt_axis.tree_.impurity[1]) < 0.01) assert(abs(dt_axis.tree_.impurity[2]) < 0.01) def test_oblique_proj(): - """Check oblique projection criterion produces correct results on small toy dataset: - - ----------------------- + """Check oblique projection criterion produces correct results on + small toy dataset + + ------------------ | X | y1 y2 | weight | - ----------------------- + ------------------ | 3 | 3 3 | 0.1 | | 5 | 3 3 | 0.3 | | 8 | 4 4 | 1.0 | | 3 | 7 7 | 0.6 | | 5 | 8 8 | 0.3 | - ----------------------- - |sum wt:| 2.3 | - ----------------------- - - Mean1 = 5 - Mean2 = 5 - - For all the samples, we can get the total error by summing: - (Mean1 - y1)^2 * weight or (Mean2 - y)^2 * weight - - I.e., error1 = (5 - 3)^2 * 0.1) - + (5 - 3)^2 * 0.3) - + (5 - 4)^2 * 1.0) - + (5 - 7)^2 * 0.6) - + (5 - 8)^2 * 0.3) - = 0.4 + 1.2 + 1.0 + 2.4 + 2.7 - = 7.7 - error_tot = 15.4 - - Impurity = error / total weight - = 7.7 / 2.3 - = 3.3478260869565 - or - = 15.4 / 2.3 - = 6.6956521739130 - or - = 0.0 - ----------------- - - From this root node, the next best split is between X values of 5 and 8. - Thus, we have left and right child nodes: - - LEFT RIGHT - ----------------------- ----------------------- - | X | y1 y2 | weight | | X | y1 y2 | weight | - ----------------------- ----------------------- - | 3 | 3 3 | 0.1 | | 8 | 4 4 | 1.0 | - | 3 | 7 7 | 0.6 | ----------------------- - | 5 | 3 3 | 0.3 | |sum wt:| 1.0 | - | 5 | 8 8 | 0.3 | ----------------------- - ----------------------- - |sum wt:| 1.3 | - ----------------------- - - (5.0625 + 3.0625 + 5.0625 + 7.5625) / 4 + 0 = 5.1875 - 4 + 4.667 = 8.667 - - Impurity is found in the same way: - Left node Mean1 = Mean2 = 5.25 - error1 = ((5.25 - 3)^2 * 0.1) - + ((5.25 - 7)^2 * 0.6) - + ((5.25 - 3)^2 * 0.3) - + ((5.25 - 8)^2 * 0.3) - = 6.13125 - error_tot = 12.2625 - - Left Impurity = Total error / total weight - = 6.13125 / 1.3 - = 4.716346153846154 - or - = 12.2625 / 1.3 - = 9.43269231 - or - = 0.0 - ------------------- - - Likewise for Right node: - Right node Mean1 = Mean2 = 4 - Total error = ((4 - 4)^2 * 1.0) - = 0 - - Right Impurity = Total error / total weight - = 0 / 1.0 - = 0.0 - ------ + ------------------ """ dt_oblique = DecisionTreeRegressor(random_state=3, criterion="oblique", max_leaf_nodes=2) + dt_mse = DecisionTreeRegressor(random_state=3, criterion="mse", + max_leaf_nodes=2) # Test oblique projection where sample weights are non-uniform (as illustrated above): dt_oblique.fit(X=[[3], [5], [8], [3], [5]], y=[[3, 3], [3, 3], [4, 4], [7, 7], [8, 8]], sample_weight=[0.1, 0.3, 1.0, 0.6, 0.3]) - print(dt_oblique.tree_.impurity) - assert(abs(7.7 / 2.3 - dt_oblique.tree_.impurity[0]) < 0.01 - or abs(2.0 * 7.7 / 2.3 - dt_oblique.tree_.impurity[0]) < 0.01 + + dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8], + sample_weight=[0.1, 0.3, 1.0, 0.6, 0.3]) + + assert(abs(dt_mse.tree_.impurity[0] - dt_oblique.tree_.impurity[0]) < 0.01 + or abs(2.0 * dt_mse.tree_.impurity[0] - dt_oblique.tree_.impurity[0]) < 0.01 or abs(dt_oblique.tree_.impurity[0]) < 0.01) - 
-    assert(abs(6.13125 / 1.3 - dt_oblique.tree_.impurity[1]) < 0.01
-           or abs(2.0 * 6.13125 / 1.3 - dt_oblique.tree_.impurity[1]) < 0.01
+    assert(abs(dt_mse.tree_.impurity[1] - dt_oblique.tree_.impurity[1]) < 0.01
+           or abs(2.0 * dt_mse.tree_.impurity[1] - dt_oblique.tree_.impurity[1]) < 0.01
            or abs(dt_oblique.tree_.impurity[1]) < 0.01)
     assert(abs(dt_oblique.tree_.impurity[2]) < 0.01)
 
     # Test oblique projection where all sample weights are uniform:
     dt_oblique.fit(X=[[3], [5], [8], [3], [5]], y=[[3,3], [3,3], [4,4], [7,7], [8,8]], sample_weight=np.ones(5))
-
-    assert(abs(22.0 / 5.0 - dt_oblique.tree_.impurity[0]) < 0.01
-           or abs(2.0 * 22.0 / 5.0 - dt_oblique.tree_.impurity[0]) < 0.01
+    dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8],
+               sample_weight=np.ones(5))
+    assert(abs(dt_mse.tree_.impurity[0] - dt_oblique.tree_.impurity[0]) < 0.01
+           or abs(2.0 * dt_mse.tree_.impurity[0] - dt_oblique.tree_.impurity[0]) < 0.01
           or abs(dt_oblique.tree_.impurity[0]) < 0.01)
-    assert(abs(20.75 / 4.0 - dt_oblique.tree_.impurity[1]) < 0.01
-           or abs(2.0 * 20.75 / 4.0 - dt_oblique.tree_.impurity[1]) < 0.01
+    assert(abs(dt_mse.tree_.impurity[1] - dt_oblique.tree_.impurity[1]) < 0.01
+           or abs(2.0 * dt_mse.tree_.impurity[1] - dt_oblique.tree_.impurity[1]) < 0.01
           or abs(dt_oblique.tree_.impurity[1]) < 0.01)
     assert(abs(dt_oblique.tree_.impurity[2]) < 0.01)
@@ -1995,11 +1875,12 @@ def test_oblique_proj():
     # This is equivalent to providing uniform sample weights, though
     # the internal logic is different:
     dt_oblique.fit(X=[[3], [5], [8], [3], [5]], y=[[3,3], [3,3], [4,4], [7,7], [8,8]])
-    assert(abs(22.0 / 5.0 - dt_oblique.tree_.impurity[0]) < 0.01
-           or abs(2.0 * 22.0 / 5.0 - dt_oblique.tree_.impurity[0]) < 0.01
+    dt_mse.fit(X=[[3], [5], [8], [3], [5]], y=[3, 3, 4, 7, 8])
+    assert(abs(dt_mse.tree_.impurity[0] - dt_oblique.tree_.impurity[0]) < 0.01
+           or abs(2.0 * dt_mse.tree_.impurity[0] - dt_oblique.tree_.impurity[0]) < 0.01
           or abs(dt_oblique.tree_.impurity[0]) < 0.01)
-    assert(abs(20.75 / 4.0 - dt_oblique.tree_.impurity[1]) < 0.01
-           or abs(2.0 * 20.75 / 4.0 - dt_oblique.tree_.impurity[1]) < 0.01
+    assert(abs(dt_mse.tree_.impurity[1] - dt_oblique.tree_.impurity[1]) < 0.01
+           or abs(2.0 * dt_mse.tree_.impurity[1] - dt_oblique.tree_.impurity[1]) < 0.01
           or abs(dt_oblique.tree_.impurity[1]) < 0.01)
     assert(abs(dt_oblique.tree_.impurity[2]) < 0.01)
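
Review note (not part of the patch): the memory saving comes from dropping the per-node
pred/mean_pred buffers and computing each node's variance from statistics the criterion
already maintains (sq_sum_* and sum_*), via Var(y) = E[y^2] - E[y]^2 with the weighted
mean. This is the weighted-variance formulation that the existing MSE criterion uses,
which lines up with the test changes: the hard-coded expected impurities are replaced by
comparisons against a criterion="mse" tree. Note also that the old loops used an
unweighted mean (y_ik / (end - start)) while weighting the squared deviations, so with
non-uniform sample weights the old hand-computed constants (e.g. 7.7 / 2.3) no longer
match the new computation exactly. A minimal NumPy sketch of the identity on the toy
node from the test docstrings (standalone illustration only; variable names are ours):

import numpy as np

# Toy node from the test docstrings: one output, non-uniform sample weights.
y = np.array([3.0, 3.0, 4.0, 7.0, 8.0])
w = np.array([0.1, 0.3, 1.0, 0.6, 0.3])

weighted_n = w.sum()  # plays the role of weighted_n_node_samples

# Two-pass form: explicit weighted mean, then weighted squared deviations.
mean = (w * y).sum() / weighted_n
impurity_two_pass = (w * (y - mean) ** 2).sum() / weighted_n

# One-pass form used after the refactor: accumulate the weighted sum of
# squares and reuse the running sum, so no per-sample buffer is allocated.
sq_sum_total = (w * y * y).sum()
sum_total = (w * y).sum()
impurity_one_pass = sq_sum_total / weighted_n - (sum_total / weighted_n) ** 2

assert np.isclose(impurity_two_pass, impurity_one_pass)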