Why does the regression model produced by XGBoost depend on the order of the training data when more than 8194 data points are used? #10834
Can confirm the finding (thanks for the very detailed report!), and made the following additional observations:
Yes, XGBoost is order-dependent when there is more than 1 thread. This is caused by floating point summation not being associative:
In a parallel environment, we split the data into chunks for multiple threads to consume:
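For illustration, here is a minimal sketch of why chunked summation is order-dependent (this mimics a parallel reduction in plain Python; it is not XGBoost's actual reduction code):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000).astype(np.float32)

# Sequential left-to-right sum in float32.
seq = np.float32(0.0)
for v in x:
    seq = np.float32(seq + v)

# Chunked sum, mimicking a parallel reduction: each "thread" sums its
# own chunk, then the partial sums are combined.
par = np.float32(0.0)
for chunk in np.array_split(x, 8):
    partial = np.float32(0.0)
    for v in chunk:
        partial = np.float32(partial + v)
    par = np.float32(par + partial)

# The two orders round differently, so the results typically disagree
# in the last bits even though both approximate the same true sum.
print(seq, par, seq == par)
```

Both results are close to the true sum; only the rounding of intermediate partial sums differs between the two orders.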
Even with a single thread, it's still order-dependent. The issue is just whether the floating-point error can accumulate to a point where it can affect the tree splits.
Thanks @maxaehle for the good ideas. I can reproduce the first observation, but not the second, when I set

@trivialfis, I appreciate the clarifications. However, I am skeptical that the nonassociativity of floating-point summation can explain the major transition in order dependence that I have described:
The threshold value 8194 is quite nearly a power of two ($2^{13} = 8192$).
The paper is a bit outdated with respect to the implementation. Floating point is a main source of such issues, but there is another source: the quantile sketching works on a stream of data and prunes the summary as more input comes in. In that case, the pruned results can depend on the arrival order of the data. You can try to verify this by using `QuantileDMatrix` with `get_quantile_cut`. If quantile is not the issue, then it's very likely just floating-point error; the gain calculation is very sensitive there.
Is the pruning method of the quantile sketching algorithm used only above this threshold?

The verification that you mention, using
This produces 4 blocks of output. The first line of each block is
Is floating-point error adequate to explain the differences between the third and fourth lines? We can investigate this in a casual manner by making a list of
Running this code yields
So I think we can expect floating-point errors to be absolutely minute, and for floating-point errors to be the source of the discrepancies, we would need to be computing a difference between two quantities that are almost identical (that is, identical out to about 14 decimal places). We would not be doing so on the first round of boosting, right? If I am missing anything, let me know.
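The magnitude claim above can be roughly checked: a plain left-to-right float64 sum of a few thousand values differs from a correctly rounded sum only around the 14th decimal place. A small sketch (using a random stand-in for per-row gradient values, not actual XGBoost gradients):

```python
import math

import numpy as np

rng = np.random.default_rng(0)
g = rng.standard_normal(8195)  # stand-in for per-row gradient values

# Naive left-to-right float64 accumulation.
s = 0.0
for v in g:
    s += v

# math.fsum computes a correctly rounded sum for reference.
exact = math.fsum(g)
abs_err = abs(s - exact)
print(abs_err)
```

The absolute error here is tiny compared to the magnitudes of the intermediate sums, which is consistent with the "about 14 decimal places" estimate.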
Hi @6nc0r6-1mp6r0 @maxaehle. Finally got the time to experiment. Apologies for the slow reply here. I simplified the example to:

```python
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=10244, random_state=1024)
Xy = xgb.QuantileDMatrix(X, label=y)
print(Xy.get_quantile_cut())

p = np.random.permutation(X.shape[0])
X = X[p]
y = y[p]
Xy = xgb.QuantileDMatrix(X, label=y)
print(Xy.get_quantile_cut())
```

You can see the difference in the cut values:
Small differences, yes. But these are defined by the input data: cut points are actual values in the dataset, and they define the tree splits. As a result, one can obtain entirely different data partitions from different decision trees. This is one of the known issues with decision trees: they are sensitive to small changes in the data and, hence, easy to overfit. In general, most numeric libraries are order-dependent; the question is more about how much this affects accuracy.
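A small sketch (with made-up cut values, not XGBoost internals) of how a tiny shift in one cut value repartitions rows, after which every later split in that subtree sees different data:

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.standard_normal(100_000)  # one feature column

# Two nearly identical candidate cut values, as might arise from
# order-dependent quantile sketching.
cut_a = 0.100
cut_b = 0.101

left_a = x <= cut_a
left_b = x <= cut_b

# Rows falling strictly between the two cuts switch sides of the split.
n_switched = int((left_a != left_b).sum())
print(n_switched)
```

Even a handful of switched rows changes the child-node statistics, and the effect compounds as deeper splits are chosen on the altered partitions.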
When I use `XGBRegressor` to construct a boosted tree model from 8194 or fewer data points (i.e., `n_train` $\leq$ 8194, where `n_train` is defined in the code below) and randomly shuffle the data points before training, the `fit` method is order independent, meaning that it generates the same predictive model each time that it is called. However, when I do the same for 8195 data points, `fit` is order dependent -- it generates a different predictive model for each call. Why is this?

I have read this paper on XGBoost and nearly all of the XGBoost documentation, and the non-subsampling algorithms described in both appear to be order independent for all `n_train`. So the source of the order dependence for large-`n_train` datasets is the mysterious part.

Below is a minimal Python script that illustrates the issue.
One run of this code yields

Note that for `n_train` = 8194, `y_test_pred[m][:n_disp]` is the same for all `m`, but for `n_train` = 8195 it is not.

Within the script, observe that I permute the elements of `X_train` and `y_train` before each run. I would expect this to have no effect on the model produced by the fitting algorithm given that, to my understanding, the feature values are sorted and binned near the start of the algorithm. However, if I comment out this permutation, the high-`n_train` order dependence of the algorithm disappears. Also note that within the `XGBRegressor` call, `tree_method` can be set to `'approx'`, `'hist'`, or `'auto'`, and `random_state` can be set to a fixed value, without eliminating the order dependence at large `n_train`.

Finally, there are several comments in the XGBoost documentation that might initially seem relevant:
For various reasons, however, I suspect that these notes are either unrelated to or inadequate to explain the abrupt transition to order dependence that I have just described.
Any clarity would be appreciated!