Skip to content

Commit

Permalink
Merge pull request rapidsai#1162 from oyilmaz-nvidia/fea-rf-rnd-seed
Browse files Browse the repository at this point in the history
[REVIEW] Fix Dask RF random seed bug reported in issue rapidsai#1050
  • Loading branch information
JohnZed authored Oct 1, 2019
2 parents 1f8588a + 3bb74dc commit 28a8dd7
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
- PR #1106: Pinning Distributed version to match Dask for consistent CI results
- PR #1116: TSNE CUDA 10.1 Bug Fixes
- PR #1132: DBSCAN Batching Bug Fix
- PR #1162: DASK RF random seed bug fix
- PR #1164: Fix check_dtype arg handling for input_to_dev_array

# cuML 0.9.0 (21 Aug 2019)
Expand Down
15 changes: 9 additions & 6 deletions cpp/src/randomforest/randomforest.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,11 @@ void postprocess_labels(int n_rows, std::vector<int>& labels,
* @param[in] cfg_n_streams: No of parallel CUDA for training forest
*/
void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
float cfg_rows_sample, int cfg_n_streams) {
float cfg_rows_sample, int cfg_seed, int cfg_n_streams) {
params.n_trees = cfg_n_trees;
params.bootstrap = cfg_bootstrap;
params.rows_sample = cfg_rows_sample;
params.seed = cfg_seed;
params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (params.n_streams == cfg_n_streams) {
std::cout << "Warning! Max setting Max streams to max openmp threads "
Expand All @@ -173,11 +174,12 @@ void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
* @param[in] cfg_tree_params: tree parameters
*/
void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
float cfg_rows_sample, int cfg_n_streams,
float cfg_rows_sample, int cfg_seed, int cfg_n_streams,
DecisionTree::DecisionTreeParams cfg_tree_params) {
params.n_trees = cfg_n_trees;
params.bootstrap = cfg_bootstrap;
params.rows_sample = cfg_rows_sample;
params.seed = cfg_seed;
params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (cfg_n_trees < params.n_streams) params.n_streams = cfg_n_trees;
set_tree_params(params.tree_params); // use input tree params
Expand Down Expand Up @@ -462,15 +464,16 @@ RF_metrics score(const cumlHandle& user_handle,
RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_rows_per_node,
bool bootstrap_features, bool bootstrap, int n_trees,
float rows_sample, CRITERION split_criterion,
bool quantile_per_tree, int cfg_n_streams) {
float rows_sample, int seed,
CRITERION split_criterion, bool quantile_per_tree,
int cfg_n_streams) {
DecisionTree::DecisionTreeParams tree_params;
DecisionTree::set_tree_params(
tree_params, max_depth, max_leaves, max_features, n_bins, split_algo,
min_rows_per_node, bootstrap_features, split_criterion, quantile_per_tree);
RF_params rf_params;
set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, cfg_n_streams,
tree_params);
set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, seed,
cfg_n_streams, tree_params);
return rf_params;
}

Expand Down
13 changes: 9 additions & 4 deletions cpp/src/randomforest/randomforest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ struct RF_params {
/**
* Decision tree training hyper parameter struct.
*/
/**
* random seed
*/
int seed;
/**
* Number of concurrent GPU streams for parallel tree building.
* Each stream is independently managed by CPU thread.
Expand All @@ -76,9 +80,9 @@ struct RF_params {

void set_rf_params(RF_params& params, int cfg_n_trees = 1,
bool cfg_bootstrap = true, float cfg_rows_sample = 1.0f,
int cfg_n_streams = 8);
int cfg_seed = -1, int cfg_n_streams = 8);
void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
float cfg_rows_sample, int cfg_n_streams,
float cfg_rows_sample, int cfg_seed, int cfg_n_streams,
DecisionTree::DecisionTreeParams cfg_tree_params);
void validity_check(const RF_params rf_params);
void print(const RF_params rf_params);
Expand Down Expand Up @@ -154,8 +158,9 @@ RF_metrics score(const cumlHandle& user_handle,
RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_rows_per_node,
bool bootstrap_features, bool bootstrap, int n_trees,
float rows_sample, CRITERION split_criterion,
bool quantile_per_tree, int cfg_n_streams);
float rows_sample, int seed,
CRITERION split_criterion, bool quantile_per_tree,
int cfg_n_streams);

// ----------------------------- Regression ----------------------------------- //

Expand Down
25 changes: 12 additions & 13 deletions cpp/src/randomforest/randomforest_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ void rf<T, L>::prepare_fit_per_tree(
int tree_id, int n_rows, int n_sampled_rows, unsigned int* selected_rows,
const int num_sms, const cudaStream_t stream,
const std::shared_ptr<deviceAllocator> device_allocator) {
srand(tree_id * 1000);
int rs = tree_id;
if (rf_params.seed > -1) rs = rf_params.seed + tree_id;

srand(rs * 1000);
if (rf_params.bootstrap) {
random_uniformInt(tree_id, selected_rows, n_sampled_rows, n_rows, num_sms,
stream);
Expand Down Expand Up @@ -221,10 +224,9 @@ void rfClassifier<T>::fit(const cumlHandle& user_handle, const T* input,
unsigned int* rowids;
rowids = selected_rows[stream_id]->data();

this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids,
tempmem[stream_id]->num_sms,
tempmem[stream_id]->stream,
handle.getDeviceAllocator());
this->prepare_fit_per_tree(
i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms,
tempmem[stream_id]->stream, handle.getDeviceAllocator());

/* Build individual tree in the forest.
- input is a pointer to orig data that have n_cols features and n_rows rows.
Expand All @@ -236,8 +238,7 @@ void rfClassifier<T>::fit(const cumlHandle& user_handle, const T* input,
*/
DecisionTree::TreeMetaDataNode<T, int>* tree_ptr = &(forest->trees[i]);
tree_ptr->treeid = i;
trees[i].fit(handle.getDeviceAllocator(),
handle.getHostAllocator(),
trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(),
tempmem[stream_id]->stream, input, n_cols, n_rows, labels,
rowids, n_sampled_rows, n_unique_labels, tree_ptr,
this->rf_params.tree_params, tempmem[stream_id]);
Expand Down Expand Up @@ -485,10 +486,9 @@ void rfRegressor<T>::fit(const cumlHandle& user_handle, const T* input,
for (int i = 0; i < this->rf_params.n_trees; i++) {
int stream_id = omp_get_thread_num();
unsigned int* rowids = selected_rows[stream_id]->data();
this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids,
tempmem[stream_id]->num_sms,
tempmem[stream_id]->stream,
handle.getDeviceAllocator());
this->prepare_fit_per_tree(
i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms,
tempmem[stream_id]->stream, handle.getDeviceAllocator());

/* Build individual tree in the forest.
- input is a pointer to orig data that have n_cols features and n_rows rows.
Expand All @@ -499,8 +499,7 @@ void rfRegressor<T>::fit(const cumlHandle& user_handle, const T* input,
*/
DecisionTree::TreeMetaDataNode<T, T>* tree_ptr = &(forest->trees[i]);
tree_ptr->treeid = i;
trees[i].fit(handle.getDeviceAllocator(),
handle.getHostAllocator(),
trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(),
tempmem[stream_id]->stream, input, n_cols, n_rows, labels,
rowids, n_sampled_rows, tree_ptr, this->rf_params.tree_params,
tempmem[stream_id]);
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/sg/rf_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class RfClassifierTest : public ::testing::TestWithParam<RfInputs<T>> {
params.split_criterion, false);
RF_params rf_params;
set_all_rf_params(rf_params, params.n_trees, params.bootstrap,
params.rows_sample, params.n_streams, tree_params);
params.rows_sample, -1, params.n_streams, tree_params);
//print(rf_params);

//--------------------------------------------------------
Expand Down Expand Up @@ -161,7 +161,7 @@ class RfRegressorTest : public ::testing::TestWithParam<RfInputs<T>> {
params.split_criterion, false);
RF_params rf_params;
set_all_rf_params(rf_params, params.n_trees, params.bootstrap,
params.rows_sample, params.n_streams, tree_params);
params.rows_sample, -1, params.n_streams, tree_params);
//print(rf_params);

//--------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/sg/rf_treelite_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class RfTreeliteTestCommon : public ::testing::TestWithParam<RfInputs<T>> {
params.min_rows_per_node, params.bootstrap_features,
params.split_criterion, false);
set_all_rf_params(rf_params, params.n_trees, params.bootstrap,
params.rows_sample, params.n_streams, tree_params);
params.rows_sample, -1, params.n_streams, tree_params);
// print(rf_params);
handle.reset(new cumlHandle(rf_params.n_streams));

Expand Down
9 changes: 9 additions & 0 deletions python/cuml/dask/ensemble/randomforestclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@ def __init__(
self.n_estimators_per_worker[i] + 1
)

seeds = list()
seeds.append(0)
for i in range(1, len(self.n_estimators_per_worker)):
sd = self.n_estimators_per_worker[i-1] + seeds[i-1]
seeds.append(sd)

key = str(uuid1())
self.rfs = {
worker: c.submit(
Expand All @@ -213,6 +219,7 @@ def __init__(
rows_sample,
max_leaves,
quantile_per_tree,
seeds[n],
dtype,
key="%s-%s" % (key, n),
workers=[worker],
Expand Down Expand Up @@ -243,6 +250,7 @@ def _func_build_rf(
rows_sample,
max_leaves,
quantile_per_tree,
seed,
dtype,
):
return cuRFC(
Expand All @@ -262,6 +270,7 @@ def _func_build_rf(
max_leaves=max_leaves,
n_streams=n_streams,
quantile_per_tree=quantile_per_tree,
seed=seed,
gdf_datatype=dtype,
)

Expand Down
9 changes: 9 additions & 0 deletions python/cuml/dask/ensemble/randomforestregressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ def __init__(
self.n_estimators_per_worker[i] + 1
)

seeds = list()
seeds.append(0)
for i in range(1, len(self.n_estimators_per_worker)):
sd = self.n_estimators_per_worker[i-1] + seeds[i-1]
seeds.append(sd)

key = str(uuid1())
self.rfs = {
worker: c.submit(
Expand All @@ -216,6 +222,7 @@ def __init__(
max_leaves,
accuracy_metric,
quantile_per_tree,
seeds[n],
key="%s-%s" % (key, n),
workers=[worker],
)
Expand Down Expand Up @@ -245,6 +252,7 @@ def _func_build_rf(
max_leaves,
accuracy_metric,
quantile_per_tree,
seed,
):

return cuRFR(
Expand All @@ -264,6 +272,7 @@ def _func_build_rf(
n_streams=n_streams,
accuracy_metric=accuracy_metric,
quantile_per_tree=quantile_per_tree,
seed=seed,
)

@staticmethod
Expand Down
7 changes: 6 additions & 1 deletion python/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML":
int n_trees
bool bootstrap
float rows_sample
int seed
pass

cdef cppclass RandomForestMetaData[T, L]:
Expand Down Expand Up @@ -181,6 +182,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML":
bool,
int,
float,
int,
CRITERION,
bool,
int) except +
Expand Down Expand Up @@ -302,7 +304,8 @@ class RandomForestClassifier(Base):
min_samples_leaf=None, min_weight_fraction_leaf=None,
max_leaf_nodes=None, min_impurity_decrease=None,
min_impurity_split=None, oob_score=None, n_jobs=None,
random_state=None, warm_start=None, class_weight=None):
random_state=None, warm_start=None, class_weight=None,
seed=-1):

sklearn_params = {"criterion": criterion,
"min_samples_leaf": min_samples_leaf,
Expand Down Expand Up @@ -353,6 +356,7 @@ class RandomForestClassifier(Base):
self.quantile_per_tree = quantile_per_tree
self.n_cols = None
self.n_streams = n_streams
self.seed = seed

cdef RandomForestMetaData[float, int] *rf_forest = \
new RandomForestMetaData[float, int]()
Expand Down Expand Up @@ -497,6 +501,7 @@ class RandomForestClassifier(Base):
<bool> self.bootstrap,
<int> self.n_estimators,
<float> self.rows_sample,
<int> self.seed,
<CRITERION> self.split_criterion,
<bool> self.quantile_per_tree,
<int> self.n_streams)
Expand Down
6 changes: 5 additions & 1 deletion python/cuml/ensemble/randomforestregressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML":
int n_trees
bool bootstrap
float rows_sample
int seed
pass

cdef cppclass RandomForestMetaData[T, L]:
Expand Down Expand Up @@ -162,6 +163,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML":
bool,
int,
float,
int,
CRITERION,
bool,
int) except +
Expand Down Expand Up @@ -286,7 +288,7 @@ class RandomForestRegressor(Base):
max_leaf_nodes=None, min_impurity_decrease=None,
min_impurity_split=None, oob_score=None,
random_state=None, warm_start=None, class_weight=None,
quantile_per_tree=False, criterion=None):
quantile_per_tree=False, criterion=None, seed=-1):

sklearn_params = {"criterion": criterion,
"min_samples_leaf": min_samples_leaf,
Expand Down Expand Up @@ -337,6 +339,7 @@ class RandomForestRegressor(Base):
self.accuracy_metric = accuracy_metric
self.quantile_per_tree = quantile_per_tree
self.n_streams = n_streams
self.seed = seed

cdef RandomForestMetaData[float, float] *rf_forest = \
new RandomForestMetaData[float, float]()
Expand Down Expand Up @@ -461,6 +464,7 @@ class RandomForestRegressor(Base):
<bool> self.bootstrap,
<int> self.n_estimators,
<float> self.rows_sample,
<int> self.seed,
<CRITERION> self.split_criterion,
<bool> self.quantile_per_tree,
<int> self.n_streams)
Expand Down

0 comments on commit 28a8dd7

Please sign in to comment.