From a2ad67f49e51b50a9e9d89b5bb21fa01ac6b85d5 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Fri, 27 Sep 2019 13:36:21 -0400 Subject: [PATCH 1/4] Initial commit for seed bug in dask RF. --- cpp/src/randomforest/randomforest.cu | 15 ++++++---- cpp/src/randomforest/randomforest.hpp | 13 ++++++--- cpp/src/randomforest/randomforest_impl.cuh | 29 ++++++++++--------- cpp/src/randomforest/randomforest_impl.h | 2 +- cpp/test/sg/rf_test.cu | 4 +-- cpp/test/sg/rf_treelite_test.cu | 2 +- .../cuml/ensemble/randomforestclassifier.pyx | 7 ++++- .../cuml/ensemble/randomforestregressor.pyx | 6 +++- 8 files changed, 48 insertions(+), 30 deletions(-) diff --git a/cpp/src/randomforest/randomforest.cu b/cpp/src/randomforest/randomforest.cu index 59c8ed9c5c..fa82017d4a 100644 --- a/cpp/src/randomforest/randomforest.cu +++ b/cpp/src/randomforest/randomforest.cu @@ -150,10 +150,11 @@ void postprocess_labels(int n_rows, std::vector& labels, * @param[in] cfg_n_streams: No of parallel CUDA for training forest */ void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, - float cfg_rows_sample, int cfg_n_streams) { + float cfg_rows_sample, int cfg_seed, int cfg_n_streams) { params.n_trees = cfg_n_trees; params.bootstrap = cfg_bootstrap; params.rows_sample = cfg_rows_sample; + params.seed = cfg_seed; params.n_streams = min(cfg_n_streams, omp_get_max_threads()); if (params.n_streams == cfg_n_streams) { std::cout << "Warning! Max setting Max streams to max openmp threads " @@ -173,11 +174,12 @@ void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, * @param[in] cfg_tree_params: tree parameters */ void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, - float cfg_rows_sample, int cfg_n_streams, + float cfg_rows_sample, int cfg_seed, int cfg_n_streams, DecisionTree::DecisionTreeParams cfg_tree_params) { params.n_trees = cfg_n_trees; params.bootstrap = cfg_bootstrap; params.rows_sample = cfg_rows_sample; + params.seed = cfg_seed; params.n_streams = min(cfg_n_streams, omp_get_max_threads()); if (cfg_n_trees < params.n_streams) params.n_streams = cfg_n_trees; set_tree_params(params.tree_params); // use input tree params @@ -462,15 +464,16 @@ RF_metrics score(const cumlHandle& user_handle, RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features, int n_bins, int split_algo, int min_rows_per_node, bool bootstrap_features, bool bootstrap, int n_trees, - float rows_sample, CRITERION split_criterion, - bool quantile_per_tree, int cfg_n_streams) { + float rows_sample, int seed, + CRITERION split_criterion, bool quantile_per_tree, + int cfg_n_streams) { DecisionTree::DecisionTreeParams tree_params; DecisionTree::set_tree_params( tree_params, max_depth, max_leaves, max_features, n_bins, split_algo, min_rows_per_node, bootstrap_features, split_criterion, quantile_per_tree); RF_params rf_params; - set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, cfg_n_streams, - tree_params); + set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, seed, + cfg_n_streams, tree_params); return rf_params; } diff --git a/cpp/src/randomforest/randomforest.hpp b/cpp/src/randomforest/randomforest.hpp index 75f78e6340..277aee5172 100644 --- a/cpp/src/randomforest/randomforest.hpp +++ b/cpp/src/randomforest/randomforest.hpp @@ -65,6 +65,10 @@ struct RF_params { /** * Decision tree training hyper parameter struct. */ + /** + * random seed + */ + int seed; /** * Number of concurrent GPU streams for parallel tree building. 
* Each stream is independently managed by CPU thread. @@ -76,9 +80,9 @@ struct RF_params { void set_rf_params(RF_params& params, int cfg_n_trees = 1, bool cfg_bootstrap = true, float cfg_rows_sample = 1.0f, - int cfg_n_streams = 8); + int cfg_seed = -1, int cfg_n_streams = 8); void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, - float cfg_rows_sample, int cfg_n_streams, + float cfg_rows_sample, int cfg_seed, int cfg_n_streams, DecisionTree::DecisionTreeParams cfg_tree_params); void validity_check(const RF_params rf_params); void print(const RF_params rf_params); @@ -154,8 +158,9 @@ RF_metrics score(const cumlHandle& user_handle, RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features, int n_bins, int split_algo, int min_rows_per_node, bool bootstrap_features, bool bootstrap, int n_trees, - float rows_sample, CRITERION split_criterion, - bool quantile_per_tree, int cfg_n_streams); + float rows_sample, int seed, + CRITERION split_criterion, bool quantile_per_tree, + int cfg_n_streams); // ----------------------------- Regression ----------------------------------- // diff --git a/cpp/src/randomforest/randomforest_impl.cuh b/cpp/src/randomforest/randomforest_impl.cuh index a1de148a10..60096c2e12 100644 --- a/cpp/src/randomforest/randomforest_impl.cuh +++ b/cpp/src/randomforest/randomforest_impl.cuh @@ -70,9 +70,12 @@ void random_uniformInt(int treeid, unsigned int* data, int len, int n_rows, template void rf::prepare_fit_per_tree( int tree_id, int n_rows, int n_sampled_rows, unsigned int* selected_rows, - const int num_sms, const cudaStream_t stream, + int seed, const int num_sms, const cudaStream_t stream, const std::shared_ptr device_allocator) { - srand(tree_id * 1000); + int rs = tree_id * 1000; + if (seed != -1) rs = seed * 1000; + + srand(rs * 1000); if (rf_params.bootstrap) { random_uniformInt(tree_id, selected_rows, n_sampled_rows, n_rows, num_sms, stream); @@ -221,10 +224,10 @@ void rfClassifier::fit(const cumlHandle& user_handle, const T* input, unsigned int* rowids; rowids = selected_rows[stream_id]->data(); - this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids, - tempmem[stream_id]->num_sms, - tempmem[stream_id]->stream, - handle.getDeviceAllocator()); + this->prepare_fit_per_tree( + i, n_rows, n_sampled_rows, rowids, (this->rf_params.seed + i), + tempmem[stream_id]->num_sms, tempmem[stream_id]->stream, + handle.getDeviceAllocator()); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. 
@@ -236,8 +239,7 @@ void rfClassifier::fit(const cumlHandle& user_handle, const T* input, */ DecisionTree::TreeMetaDataNode* tree_ptr = &(forest->trees[i]); tree_ptr->treeid = i; - trees[i].fit(handle.getDeviceAllocator(), - handle.getHostAllocator(), + trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(), tempmem[stream_id]->stream, input, n_cols, n_rows, labels, rowids, n_sampled_rows, n_unique_labels, tree_ptr, this->rf_params.tree_params, tempmem[stream_id]); @@ -485,10 +487,10 @@ void rfRegressor::fit(const cumlHandle& user_handle, const T* input, for (int i = 0; i < this->rf_params.n_trees; i++) { int stream_id = omp_get_thread_num(); unsigned int* rowids = selected_rows[stream_id]->data(); - this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids, - tempmem[stream_id]->num_sms, - tempmem[stream_id]->stream, - handle.getDeviceAllocator()); + this->prepare_fit_per_tree( + i, n_rows, n_sampled_rows, rowids, (this->rf_params.seed + i), + tempmem[stream_id]->num_sms, tempmem[stream_id]->stream, + handle.getDeviceAllocator()); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. @@ -499,8 +501,7 @@ void rfRegressor::fit(const cumlHandle& user_handle, const T* input, */ DecisionTree::TreeMetaDataNode* tree_ptr = &(forest->trees[i]); tree_ptr->treeid = i; - trees[i].fit(handle.getDeviceAllocator(), - handle.getHostAllocator(), + trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(), tempmem[stream_id]->stream, input, n_cols, n_rows, labels, rowids, n_sampled_rows, tree_ptr, this->rf_params.tree_params, tempmem[stream_id]); diff --git a/cpp/src/randomforest/randomforest_impl.h b/cpp/src/randomforest/randomforest_impl.h index 3369c82ae1..cafa675838 100644 --- a/cpp/src/randomforest/randomforest_impl.h +++ b/cpp/src/randomforest/randomforest_impl.h @@ -30,7 +30,7 @@ class rf { virtual ~rf() = default; void prepare_fit_per_tree( int tree_id, int n_rows, int n_sampled_rows, unsigned int* selected_rows, - int num_sms, const cudaStream_t stream, + int seed, int num_sms, const cudaStream_t stream, const std::shared_ptr device_allocator); void error_checking(const T* input, L* predictions, int n_rows, int n_cols, diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu index 512c53c2af..115cbd8e97 100644 --- a/cpp/test/sg/rf_test.cu +++ b/cpp/test/sg/rf_test.cu @@ -61,7 +61,7 @@ class RfClassifierTest : public ::testing::TestWithParam> { params.split_criterion, false); RF_params rf_params; set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.rows_sample, params.n_streams, tree_params); + params.rows_sample, -1, params.n_streams, tree_params); //print(rf_params); //-------------------------------------------------------- @@ -161,7 +161,7 @@ class RfRegressorTest : public ::testing::TestWithParam> { params.split_criterion, false); RF_params rf_params; set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.rows_sample, params.n_streams, tree_params); + params.rows_sample, -1, params.n_streams, tree_params); //print(rf_params); //-------------------------------------------------------- diff --git a/cpp/test/sg/rf_treelite_test.cu b/cpp/test/sg/rf_treelite_test.cu index d772b0c83e..279db9588e 100644 --- a/cpp/test/sg/rf_treelite_test.cu +++ b/cpp/test/sg/rf_treelite_test.cu @@ -181,7 +181,7 @@ class RfTreeliteTestCommon : public ::testing::TestWithParam> { params.min_rows_per_node, params.bootstrap_features, params.split_criterion, false); set_all_rf_params(rf_params, 
params.n_trees, params.bootstrap, - params.rows_sample, params.n_streams, tree_params); + params.rows_sample, -1, params.n_streams, tree_params); // print(rf_params); handle.reset(new cumlHandle(rf_params.n_streams)); diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 04f848a92d..8dc308bd9d 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -81,6 +81,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML": int n_trees bool bootstrap float rows_sample + int seed pass cdef cppclass RandomForestMetaData[T, L]: @@ -181,6 +182,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML": bool, int, float, + int, CRITERION, bool, int) except + @@ -302,7 +304,8 @@ class RandomForestClassifier(Base): min_samples_leaf=None, min_weight_fraction_leaf=None, max_leaf_nodes=None, min_impurity_decrease=None, min_impurity_split=None, oob_score=None, n_jobs=None, - random_state=None, warm_start=None, class_weight=None): + random_state=None, warm_start=None, class_weight=None, + int seed=-1): sklearn_params = {"criterion": criterion, "min_samples_leaf": min_samples_leaf, @@ -353,6 +356,7 @@ class RandomForestClassifier(Base): self.quantile_per_tree = quantile_per_tree self.n_cols = None self.n_streams = n_streams + self.seed = seed cdef RandomForestMetaData[float, int] *rf_forest = \ new RandomForestMetaData[float, int]() @@ -497,6 +501,7 @@ class RandomForestClassifier(Base): self.bootstrap, self.n_estimators, self.rows_sample, + self.seed, self.split_criterion, self.quantile_per_tree, self.n_streams) diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index e6727a47a8..f86ff332ad 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -80,6 +80,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML": int n_trees bool bootstrap float rows_sample + int seed pass cdef cppclass RandomForestMetaData[T, L]: @@ -162,6 +163,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML": bool, int, float, + int, CRITERION, bool, int) except + @@ -286,7 +288,7 @@ class RandomForestRegressor(Base): max_leaf_nodes=None, min_impurity_decrease=None, min_impurity_split=None, oob_score=None, random_state=None, warm_start=None, class_weight=None, - quantile_per_tree=False, criterion=None): + quantile_per_tree=False, criterion=None, int seed=-1): sklearn_params = {"criterion": criterion, "min_samples_leaf": min_samples_leaf, @@ -337,6 +339,7 @@ class RandomForestRegressor(Base): self.accuracy_metric = accuracy_metric self.quantile_per_tree = quantile_per_tree self.n_streams = n_streams + self.seed = seed cdef RandomForestMetaData[float, float] *rf_forest = \ new RandomForestMetaData[float, float]() @@ -461,6 +464,7 @@ class RandomForestRegressor(Base): self.bootstrap, self.n_estimators, self.rows_sample, + self.seed, self.split_criterion, self.quantile_per_tree, self.n_streams) From ebf4233249755361142572129e35d66c186f761c Mon Sep 17 00:00:00 2001 From: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com> Date: Fri, 27 Sep 2019 13:39:24 -0400 Subject: [PATCH 2/4] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8bc4a6279..85a841dd55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,7 @@ - PR #1106: Pinning Distributed version to match Dask for 
consistent CI results - PR #1116: TSNE CUDA 10.1 Bug Fixes - PR #1132: DBSCAN Batching Bug Fix +- PR #1162: DASK RF random seed bug fix # cuML 0.9.0 (21 Aug 2019) From 0b6db552fcadfe139cab44280bc9ab20e6877e92 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Sat, 28 Sep 2019 11:56:13 -0400 Subject: [PATCH 3/4] Dask RF updated to fix the random seed bug. --- cpp/src/randomforest/randomforest_impl.cuh | 16 +++++++--------- cpp/src/randomforest/randomforest_impl.h | 2 +- .../cuml/dask/ensemble/randomforestclassifier.py | 9 +++++++++ .../cuml/dask/ensemble/randomforestregressor.py | 9 +++++++++ 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/cpp/src/randomforest/randomforest_impl.cuh b/cpp/src/randomforest/randomforest_impl.cuh index 60096c2e12..471e6da754 100644 --- a/cpp/src/randomforest/randomforest_impl.cuh +++ b/cpp/src/randomforest/randomforest_impl.cuh @@ -70,10 +70,10 @@ void random_uniformInt(int treeid, unsigned int* data, int len, int n_rows, template void rf::prepare_fit_per_tree( int tree_id, int n_rows, int n_sampled_rows, unsigned int* selected_rows, - int seed, const int num_sms, const cudaStream_t stream, + const int num_sms, const cudaStream_t stream, const std::shared_ptr device_allocator) { - int rs = tree_id * 1000; - if (seed != -1) rs = seed * 1000; + int rs = tree_id; + if (rf_params.seed > -1) rs = rf_params.seed + tree_id; srand(rs * 1000); if (rf_params.bootstrap) { @@ -225,9 +225,8 @@ void rfClassifier::fit(const cumlHandle& user_handle, const T* input, rowids = selected_rows[stream_id]->data(); this->prepare_fit_per_tree( - i, n_rows, n_sampled_rows, rowids, (this->rf_params.seed + i), - tempmem[stream_id]->num_sms, tempmem[stream_id]->stream, - handle.getDeviceAllocator()); + i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms, + tempmem[stream_id]->stream, handle.getDeviceAllocator()); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. @@ -488,9 +487,8 @@ void rfRegressor::fit(const cumlHandle& user_handle, const T* input, int stream_id = omp_get_thread_num(); unsigned int* rowids = selected_rows[stream_id]->data(); this->prepare_fit_per_tree( - i, n_rows, n_sampled_rows, rowids, (this->rf_params.seed + i), - tempmem[stream_id]->num_sms, tempmem[stream_id]->stream, - handle.getDeviceAllocator()); + i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms, + tempmem[stream_id]->stream, handle.getDeviceAllocator()); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. 
diff --git a/cpp/src/randomforest/randomforest_impl.h b/cpp/src/randomforest/randomforest_impl.h index cafa675838..3369c82ae1 100644 --- a/cpp/src/randomforest/randomforest_impl.h +++ b/cpp/src/randomforest/randomforest_impl.h @@ -30,7 +30,7 @@ class rf { virtual ~rf() = default; void prepare_fit_per_tree( int tree_id, int n_rows, int n_sampled_rows, unsigned int* selected_rows, - int seed, int num_sms, const cudaStream_t stream, + int num_sms, const cudaStream_t stream, const std::shared_ptr device_allocator); void error_checking(const T* input, L* predictions, int n_rows, int n_cols, diff --git a/python/cuml/dask/ensemble/randomforestclassifier.py b/python/cuml/dask/ensemble/randomforestclassifier.py index 9e063e18a4..8187af137a 100755 --- a/python/cuml/dask/ensemble/randomforestclassifier.py +++ b/python/cuml/dask/ensemble/randomforestclassifier.py @@ -194,6 +194,12 @@ def __init__( self.n_estimators_per_worker[i] + 1 ) + seeds = list() + seeds.append(0) + for i in range(1, len(self.n_estimators_per_worker)): + sd = self.n_estimators_per_worker[i-1] + seeds[i-1] + seeds.append(sd) + key = str(uuid1()) self.rfs = { worker: c.submit( @@ -213,6 +219,7 @@ def __init__( rows_sample, max_leaves, quantile_per_tree, + seeds[n], dtype, key="%s-%s" % (key, n), workers=[worker], @@ -243,6 +250,7 @@ def _func_build_rf( rows_sample, max_leaves, quantile_per_tree, + seed, dtype, ): return cuRFC( @@ -262,6 +270,7 @@ def _func_build_rf( max_leaves=max_leaves, n_streams=n_streams, quantile_per_tree=quantile_per_tree, + seed=seed, gdf_datatype=dtype, ) diff --git a/python/cuml/dask/ensemble/randomforestregressor.py b/python/cuml/dask/ensemble/randomforestregressor.py index b86e1a9269..e12bb75056 100755 --- a/python/cuml/dask/ensemble/randomforestregressor.py +++ b/python/cuml/dask/ensemble/randomforestregressor.py @@ -197,6 +197,12 @@ def __init__( self.n_estimators_per_worker[i] + 1 ) + seeds = list() + seeds.append(0) + for i in range(1, len(self.n_estimators_per_worker)): + sd = self.n_estimators_per_worker[i-1] + seeds[i-1] + seeds.append(sd) + key = str(uuid1()) self.rfs = { worker: c.submit( @@ -216,6 +222,7 @@ def __init__( max_leaves, accuracy_metric, quantile_per_tree, + seeds[n], key="%s-%s" % (key, n), workers=[worker], ) @@ -245,6 +252,7 @@ def _func_build_rf( max_leaves, accuracy_metric, quantile_per_tree, + seed, ): return cuRFR( @@ -264,6 +272,7 @@ def _func_build_rf( n_streams=n_streams, accuracy_metric=accuracy_metric, quantile_per_tree=quantile_per_tree, + seed=seed, ) @staticmethod From 0eacd14a382869208ebc04d30dda61a2c0dc0d53 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Sat, 28 Sep 2019 12:02:49 -0400 Subject: [PATCH 4/4] Fixed the formatting issues. 
--- python/cuml/ensemble/randomforestclassifier.pyx | 4 ++-- python/cuml/ensemble/randomforestregressor.pyx | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 8dc308bd9d..5d7aadb656 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -304,8 +304,8 @@ class RandomForestClassifier(Base): min_samples_leaf=None, min_weight_fraction_leaf=None, max_leaf_nodes=None, min_impurity_decrease=None, min_impurity_split=None, oob_score=None, n_jobs=None, - random_state=None, warm_start=None, class_weight=None, - int seed=-1): + random_state=None, warm_start=None, class_weight=None, + seed=-1): sklearn_params = {"criterion": criterion, "min_samples_leaf": min_samples_leaf, diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index f86ff332ad..b3cd2f2bf6 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -288,7 +288,7 @@ class RandomForestRegressor(Base): max_leaf_nodes=None, min_impurity_decrease=None, min_impurity_split=None, oob_score=None, random_state=None, warm_start=None, class_weight=None, - quantile_per_tree=False, criterion=None, int seed=-1): + quantile_per_tree=False, criterion=None, seed=-1): sklearn_params = {"criterion": criterion, "min_samples_leaf": min_samples_leaf,
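Note on the seed plumbing introduced above: the Dask layer (patch 3) assigns each worker a base seed equal to the number of trees built on earlier workers, and the C++ side derives one RNG state per tree from that base (rs = seed + tree_id when seed > -1, else rs = tree_id, then srand(rs * 1000)). The following is a minimal standalone Python sketch of that arithmetic, not code from the patches themselves; the helper names worker_seeds and per_tree_srand_argument are invented for illustration, while n_estimators_per_worker mirrors the attribute used in the Dask estimators.

# Standalone sketch (assumptions noted in the lead-in above), illustrating
# how per-worker and per-tree seeds in this patch series fit together.

def worker_seeds(n_estimators_per_worker):
    """Offset each worker's base seed by the number of trees built on
    earlier workers, so no two trees in the distributed forest share a
    seed (mirrors the seeds list built in the Dask __init__ of patch 3)."""
    seeds = [0]
    for i in range(1, len(n_estimators_per_worker)):
        seeds.append(seeds[i - 1] + n_estimators_per_worker[i - 1])
    return seeds

def per_tree_srand_argument(seed, tree_id):
    """Value passed to srand() for one tree, following the patch-3 logic in
    prepare_fit_per_tree: seeded runs use seed + tree_id, unseeded runs
    (seed == -1) fall back to tree_id alone."""
    rs = seed + tree_id if seed > -1 else tree_id
    return rs * 1000

if __name__ == "__main__":
    trees_per_worker = [10, 10, 5]           # e.g. n_estimators=25 across 3 workers
    seeds = worker_seeds(trees_per_worker)   # -> [0, 10, 20]
    print(seeds)
    # Worker 1 builds its 10 trees from srand(10000), srand(11000), ...,
    # none of which collide with worker 0's srand(0), srand(1000), ...
    print([per_tree_srand_argument(seeds[1], t) for t in range(trees_per_worker[1])])

The design choice to offset by cumulative tree counts keeps the per-tree seeds disjoint across workers for any split of n_estimators, so a fixed user-supplied seed reproduces the same forest regardless of how many Dask workers participate.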