Commit: Merge branch 'branch-0.10' into patch-12
cjnolet authored Oct 1, 2019
2 parents c0b8cac + 28a8dd7 commit e3694c8
Showing 37 changed files with 328 additions and 177 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -22,10 +22,12 @@
- PR #1086: Ensure RegressorMixin scorer uses device arrays
- PR #1108: input_to_host_array function in input_utils for input processing to host arrays
- PR #1114: K-means: Exposing useful params, removing unused params, proxying params in Dask
- PR #1142: prims: expose separate InType and OutType for unaryOp and binaryOp
- PR #1115: Moving dask_make_blobs to cuml.dask.datasets. Adding conversion to dask.DataFrame
- PR #1136: CUDA 10.1 CI updates
- PR #1165: Adding except + in all remaining cython
- PR #1173: Docs: Barnes Hut TSNE documentation
- PR #1176: Use new RMM API based on Cython

## Bug Fixes

@@ -44,6 +46,7 @@
- PR #1106: Pinning Distributed version to match Dask for consistent CI results
- PR #1116: TSNE CUDA 10.1 Bug Fixes
- PR #1132: DBSCAN Batching Bug Fix
- PR #1162: DASK RF random seed bug fix
- PR #1164: Fix check_dtype arg handling for input_to_dev_array

# cuML 0.9.0 (21 Aug 2019)
@@ -105,6 +108,7 @@
- PR #978: Update README for 0.9
- PR #1009: Fix references to notebooks-contrib
- PR #1015: Ability to control the number of internal streams in cumlHandle_impl via cumlHandle
- PR #1175: Add more modules to docs ToC

## Bug Fixes

1 change: 1 addition & 0 deletions README.md
@@ -84,6 +84,7 @@ repo](https://github.com/rapidsai/notebooks-contrib).
| **Nonlinear Models for Regression or Classification** | Random Forest (RF) Classification | Experimental multi-node, multi-GPU version available via Dask integration |
| | Random Forest (RF) Regression | Experimental multi-node, multi-GPU version available via Dask integration |
| | K-Nearest Neighbors (KNN) | Multi-GPU <br> Uses [Faiss](https://github.com/facebookresearch/faiss) |
| | Support Vector Machine Classifier (SVC) | |
| **Time Series** | Linear Kalman Filter | |
| | Holt-Winters Exponential Smoothing | |
---
15 changes: 9 additions & 6 deletions cpp/src/randomforest/randomforest.cu
@@ -150,10 +150,11 @@ void postprocess_labels(int n_rows, std::vector<int>& labels,
* @param[in] cfg_n_streams: Number of parallel CUDA streams for training the forest
*/
void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
float cfg_rows_sample, int cfg_n_streams) {
float cfg_rows_sample, int cfg_seed, int cfg_n_streams) {
params.n_trees = cfg_n_trees;
params.bootstrap = cfg_bootstrap;
params.rows_sample = cfg_rows_sample;
params.seed = cfg_seed;
params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (params.n_streams == cfg_n_streams) {
std::cout << "Warning! Setting max streams to max OpenMP threads "
@@ -173,11 +174,12 @@ void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
* @param[in] cfg_tree_params: tree parameters
*/
void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
float cfg_rows_sample, int cfg_n_streams,
float cfg_rows_sample, int cfg_seed, int cfg_n_streams,
DecisionTree::DecisionTreeParams cfg_tree_params) {
params.n_trees = cfg_n_trees;
params.bootstrap = cfg_bootstrap;
params.rows_sample = cfg_rows_sample;
params.seed = cfg_seed;
params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (cfg_n_trees < params.n_streams) params.n_streams = cfg_n_trees;
set_tree_params(params.tree_params); // use input tree params
@@ -462,15 +464,16 @@ RF_metrics score(const cumlHandle& user_handle,
RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_rows_per_node,
bool bootstrap_features, bool bootstrap, int n_trees,
float rows_sample, CRITERION split_criterion,
bool quantile_per_tree, int cfg_n_streams) {
float rows_sample, int seed,
CRITERION split_criterion, bool quantile_per_tree,
int cfg_n_streams) {
DecisionTree::DecisionTreeParams tree_params;
DecisionTree::set_tree_params(
tree_params, max_depth, max_leaves, max_features, n_bins, split_algo,
min_rows_per_node, bootstrap_features, split_criterion, quantile_per_tree);
RF_params rf_params;
set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, cfg_n_streams,
tree_params);
set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, seed,
cfg_n_streams, tree_params);
return rf_params;
}
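For context, a minimal sketch of how a caller (for instance the Cython layer) might build RF_params through this entry point with the new seed argument; all argument values below are illustrative, not taken from this commit:

```cpp
// Hypothetical call to set_rf_class_obj (declared in randomforest.hpp).
// seed = 42 makes per-tree row sampling reproducible across fits;
// seed = -1 would keep the legacy seeding based on the tree id alone.
RF_params rf_params = set_rf_class_obj(
  /*max_depth=*/16, /*max_leaves=*/-1, /*max_features=*/1.0f,
  /*n_bins=*/8, /*split_algo=*/1, /*min_rows_per_node=*/2,
  /*bootstrap_features=*/false, /*bootstrap=*/true, /*n_trees=*/100,
  /*rows_sample=*/1.0f, /*seed=*/42, /*split_criterion=*/CRITERION::GINI,
  /*quantile_per_tree=*/false, /*cfg_n_streams=*/8);
```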

13 changes: 9 additions & 4 deletions cpp/src/randomforest/randomforest.hpp
@@ -65,6 +65,10 @@ struct RF_params {
/**
* Decision tree training hyperparameter struct.
*/
/**
* Random seed used to derive per-tree seeds; -1 (the default) keeps the
* legacy seeding based on the tree id alone.
*/
int seed;
/**
* Number of concurrent GPU streams for parallel tree building.
* Each stream is independently managed by CPU thread.
@@ -76,9 +80,9 @@

void set_rf_params(RF_params& params, int cfg_n_trees = 1,
bool cfg_bootstrap = true, float cfg_rows_sample = 1.0f,
int cfg_n_streams = 8);
int cfg_seed = -1, int cfg_n_streams = 8);
void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
float cfg_rows_sample, int cfg_n_streams,
float cfg_rows_sample, int cfg_seed, int cfg_n_streams,
DecisionTree::DecisionTreeParams cfg_tree_params);
void validity_check(const RF_params rf_params);
void print(const RF_params rf_params);
@@ -154,8 +158,9 @@ RF_metrics score(const cumlHandle& user_handle,
RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_rows_per_node,
bool bootstrap_features, bool bootstrap, int n_trees,
float rows_sample, CRITERION split_criterion,
bool quantile_per_tree, int cfg_n_streams);
float rows_sample, int seed,
CRITERION split_criterion, bool quantile_per_tree,
int cfg_n_streams);

// ----------------------------- Regression ----------------------------------- //

25 changes: 12 additions & 13 deletions cpp/src/randomforest/randomforest_impl.cuh
@@ -72,7 +72,10 @@ void rf<T, L>::prepare_fit_per_tree(
int tree_id, int n_rows, int n_sampled_rows, unsigned int* selected_rows,
const int num_sms, const cudaStream_t stream,
const std::shared_ptr<deviceAllocator> device_allocator) {
srand(tree_id * 1000);
int rs = tree_id;
if (rf_params.seed > -1) rs = rf_params.seed + tree_id;

srand(rs * 1000);
if (rf_params.bootstrap) {
random_uniformInt(tree_id, selected_rows, n_sampled_rows, n_rows, num_sms,
stream);
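The effect of the new seeding logic above, as a standalone sketch (ParamsSketch and per_tree_seed are simplified stand-ins for RF_params and prepare_fit_per_tree):

```cpp
// Illustration of the per-tree seed derivation introduced in this commit.
#include <cstdio>
#include <cstdlib>

struct ParamsSketch {
  int seed;  // -1 (default): seed depends on the tree id alone
};

int per_tree_seed(const ParamsSketch& params, int tree_id) {
  int rs = tree_id;
  if (params.seed > -1) rs = params.seed + tree_id;  // user seed offsets each tree
  return rs * 1000;  // same scaling as the srand() call above
}

int main() {
  ParamsSketch params{42};
  for (int t = 0; t < 3; ++t) {
    std::srand(per_tree_seed(params, t));  // host RNG used for row sampling
    std::printf("tree %d seeded with %d\n", t, per_tree_seed(params, t));
  }
}
```

Seeds are deterministic per tree either way; a user-supplied seed shifts the whole sequence, so repeated fits with the same seed draw the same bootstrap samples.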
@@ -221,10 +224,9 @@ void rfClassifier<T>::fit(const cumlHandle& user_handle, const T* input,
unsigned int* rowids;
rowids = selected_rows[stream_id]->data();

this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids,
tempmem[stream_id]->num_sms,
tempmem[stream_id]->stream,
handle.getDeviceAllocator());
this->prepare_fit_per_tree(
i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms,
tempmem[stream_id]->stream, handle.getDeviceAllocator());

/* Build individual tree in the forest.
- input is a pointer to the original data, which has n_cols features and n_rows rows.
@@ -236,8 +238,7 @@
*/
DecisionTree::TreeMetaDataNode<T, int>* tree_ptr = &(forest->trees[i]);
tree_ptr->treeid = i;
trees[i].fit(handle.getDeviceAllocator(),
handle.getHostAllocator(),
trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(),
tempmem[stream_id]->stream, input, n_cols, n_rows, labels,
rowids, n_sampled_rows, n_unique_labels, tree_ptr,
this->rf_params.tree_params, tempmem[stream_id]);
@@ -485,10 +486,9 @@ void rfRegressor<T>::fit(const cumlHandle& user_handle, const T* input,
for (int i = 0; i < this->rf_params.n_trees; i++) {
int stream_id = omp_get_thread_num();
unsigned int* rowids = selected_rows[stream_id]->data();
this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids,
tempmem[stream_id]->num_sms,
tempmem[stream_id]->stream,
handle.getDeviceAllocator());
this->prepare_fit_per_tree(
i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms,
tempmem[stream_id]->stream, handle.getDeviceAllocator());

/* Build individual tree in the forest.
- input is a pointer to the original data, which has n_cols features and n_rows rows.
@@ -499,8 +499,7 @@
*/
DecisionTree::TreeMetaDataNode<T, T>* tree_ptr = &(forest->trees[i]);
tree_ptr->treeid = i;
trees[i].fit(handle.getDeviceAllocator(),
handle.getHostAllocator(),
trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(),
tempmem[stream_id]->stream, input, n_cols, n_rows, labels,
rowids, n_sampled_rows, tree_ptr, this->rf_params.tree_params,
tempmem[stream_id]);
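Both fit loops follow the same pattern: trees are distributed across an OpenMP thread pool, and each thread owns one CUDA stream and one tempmem scratch slot. A minimal sketch of that mapping, with the actual tree building omitted:

```cpp
// Sketch of the tree-to-stream mapping used by the fit loops above.
#include <omp.h>
#include <algorithm>
#include <cstdio>

int main() {
  const int n_trees = 8;
  // n_streams is clamped to the OpenMP thread count, as in set_rf_params.
  const int n_streams = std::min(4, omp_get_max_threads());
#pragma omp parallel for num_threads(n_streams)
  for (int i = 0; i < n_trees; ++i) {
    int stream_id = omp_get_thread_num();  // index of this thread's stream/tempmem
    std::printf("tree %d built on stream %d\n", i, stream_id);
  }
}
```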
70 changes: 39 additions & 31 deletions cpp/src_prims/linalg/binary_op.h
@@ -22,37 +22,41 @@
namespace MLCommon {
namespace LinAlg {

template <typename math_t, int veclen_, typename Lambda, typename IdxType>
__global__ void binaryOpKernel(math_t *out, const math_t *in1,
const math_t *in2, IdxType len, Lambda op) {
typedef TxN_t<math_t, veclen_> VecType;
VecType a, b;
template <typename InType, int VecLen, typename Lambda, typename IdxType,
typename OutType>
__global__ void binaryOpKernel(OutType *out, const InType *in1,
const InType *in2, IdxType len, Lambda op) {
typedef TxN_t<InType, VecLen> InVecType;
typedef TxN_t<OutType, VecLen> OutVecType;
InVecType a, b;
OutVecType c;
IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * blockDim.x);
idx *= VecType::Ratio;
idx *= InVecType::Ratio;
if (idx >= len) return;
a.load(in1, idx);
b.load(in2, idx);
#pragma unroll
for (int i = 0; i < VecType::Ratio; ++i) {
a.val.data[i] = op(a.val.data[i], b.val.data[i]);
for (int i = 0; i < InVecType::Ratio; ++i) {
c.val.data[i] = op(a.val.data[i], b.val.data[i]);
}
a.store(out, idx);
c.store(out, idx);
}

template <typename math_t, int veclen_, typename Lambda, typename IdxType,
int TPB>
void binaryOpImpl(math_t *out, const math_t *in1, const math_t *in2,
template <typename InType, int VecLen, typename Lambda, typename IdxType,
typename OutType, int TPB>
void binaryOpImpl(OutType *out, const InType *in1, const InType *in2,
IdxType len, Lambda op, cudaStream_t stream) {
const IdxType nblks = ceildiv(veclen_ ? len / veclen_ : len, (IdxType)TPB);
binaryOpKernel<math_t, veclen_, Lambda, IdxType>
const IdxType nblks = ceildiv(VecLen ? len / VecLen : len, (IdxType)TPB);
binaryOpKernel<InType, VecLen, Lambda, IdxType, OutType>
<<<nblks, TPB, 0, stream>>>(out, in1, in2, len, op);
CUDA_CHECK(cudaPeekAtLastError());
}

/**
* @brief perform element-wise binary operation on the input arrays
* @tparam math_t data-type upon which the math operation will be performed
* @tparam InType input data-type
* @tparam Lambda the device-lambda performing the actual operation
* @tparam OutType output data-type
* @tparam IdxType Integer type used for addressing
* @tparam TPB threads-per-block in the final kernel launched
* @param out the output array
@@ -61,30 +65,34 @@ void binaryOpImpl(math_t *out, const math_t *in1, const math_t *in2,
* @param len number of elements in the input array
* @param op the device-lambda
* @param stream CUDA stream on which to launch the work
* @note Lambda must be a functor with the following signature:
* `OutType func(const InType& val1, const InType& val2);`
*/
template <typename math_t, typename Lambda, typename IdxType = int,
int TPB = 256>
void binaryOp(math_t *out, const math_t *in1, const math_t *in2, IdxType len,
template <typename InType, typename Lambda, typename OutType = InType,
typename IdxType = int, int TPB = 256>
void binaryOp(OutType *out, const InType *in1, const InType *in2, IdxType len,
Lambda op, cudaStream_t stream) {
size_t bytes = len * sizeof(math_t);
if (16 / sizeof(math_t) && bytes % 16 == 0) {
binaryOpImpl<math_t, 16 / sizeof(math_t), Lambda, IdxType, TPB>(
constexpr auto maxSize =
sizeof(InType) > sizeof(OutType) ? sizeof(InType) : sizeof(OutType);
size_t bytes = len * maxSize;
if (16 / maxSize && bytes % 16 == 0) {
binaryOpImpl<InType, 16 / maxSize, Lambda, IdxType, OutType, TPB>(
out, in1, in2, len, op, stream);
} else if (8 / sizeof(math_t) && bytes % 8 == 0) {
binaryOpImpl<math_t, 8 / sizeof(math_t), Lambda, IdxType, TPB>(
} else if (8 / maxSize && bytes % 8 == 0) {
binaryOpImpl<InType, 8 / maxSize, Lambda, IdxType, OutType, TPB>(
out, in1, in2, len, op, stream);
} else if (4 / sizeof(math_t) && bytes % 4 == 0) {
binaryOpImpl<math_t, 4 / sizeof(math_t), Lambda, IdxType, TPB>(
} else if (4 / maxSize && bytes % 4 == 0) {
binaryOpImpl<InType, 4 / maxSize, Lambda, IdxType, OutType, TPB>(
out, in1, in2, len, op, stream);
} else if (2 / sizeof(math_t) && bytes % 2 == 0) {
binaryOpImpl<math_t, 2 / sizeof(math_t), Lambda, IdxType, TPB>(
} else if (2 / maxSize && bytes % 2 == 0) {
binaryOpImpl<InType, 2 / maxSize, Lambda, IdxType, OutType, TPB>(
out, in1, in2, len, op, stream);
} else if (1 / sizeof(math_t)) {
binaryOpImpl<math_t, 1 / sizeof(math_t), Lambda, IdxType, TPB>(
} else if (1 / maxSize) {
binaryOpImpl<InType, 1 / maxSize, Lambda, IdxType, OutType, TPB>(
out, in1, in2, len, op, stream);
} else {
binaryOpImpl<math_t, 1, Lambda, IdxType, TPB>(out, in1, in2, len, op,
stream);
binaryOpImpl<InType, 1, Lambda, IdxType, OutType, TPB>(out, in1, in2, len,
op, stream);
}
}
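A sketch of what the dispatch above computes, reduced to host code: the widest vectorized access (16 bytes down to scalar) that the element sizes and the total byte count allow, now keyed on the larger of the input and output element sizes:

```cpp
// Standalone illustration of binaryOp's vector-width selection; the real
// kernel launch is replaced by returning the chosen elements-per-vector.
#include <cstddef>
#include <cstdio>

template <typename InType, typename OutType>
int chooseVecLen(size_t len) {
  constexpr size_t maxSize =
    sizeof(InType) > sizeof(OutType) ? sizeof(InType) : sizeof(OutType);
  size_t bytes = len * maxSize;
  if (16 / maxSize && bytes % 16 == 0) return 16 / maxSize;
  if (8 / maxSize && bytes % 8 == 0) return 8 / maxSize;
  if (4 / maxSize && bytes % 4 == 0) return 4 / maxSize;
  if (2 / maxSize && bytes % 2 == 0) return 2 / maxSize;
  return 1;
}

int main() {
  // float in, double out: maxSize = 8, so at most 2 elements per vector op.
  std::printf("float -> double: veclen %d\n", chooseVecLen<float, double>(1024));
  // float in and out: maxSize = 4, allowing 4-wide (16-byte) accesses.
  std::printf("float -> float:  veclen %d\n", chooseVecLen<float, float>(1024));
}
```

Per the @note above, a lambda passed to binaryOp must have the shape `OutType func(const InType& val1, const InType& val2)`, e.g. an extended device lambda `[] __device__ (const float& a, const float& b) -> double { return double(a) + double(b); }` for a float-in, double-out sum.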

(Diffs for the remaining changed files are not shown.)