diff --git a/CHANGELOG.md b/CHANGELOG.md
index 51c3016454..a68c04ece0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,10 +22,12 @@
- PR #1086: Ensure RegressorMixin scorer uses device arrays
- PR #1108: input_to_host_array function in input_utils for input processing to host arrays
- PR #1114: K-means: Exposing useful params, removing unused params, proxying params in Dask
+- PR #1142: prims: expose separate InType and OutType for unaryOp and binaryOp
- PR #1115: Moving dask_make_blobs to cuml.dask.datasets. Adding conversion to dask.DataFrame
- PR #1136: CUDA 10.1 CI updates
- PR #1165: Adding except + in all remaining cython
- PR #1173: Docs: Barnes Hut TSNE documentation
+- PR #1176: Use new RMM API based on Cython
## Bug Fixes
@@ -44,6 +46,7 @@
- PR #1106: Pinning Distributed version to match Dask for consistent CI results
- PR #1116: TSNE CUDA 10.1 Bug Fixes
- PR #1132: DBSCAN Batching Bug Fix
+- PR #1162: DASK RF random seed bug fix
- PR #1164: Fix check_dtype arg handling for input_to_dev_array
# cuML 0.9.0 (21 Aug 2019)
@@ -105,6 +108,7 @@
- PR #978: Update README for 0.9
- PR #1009: Fix references to notebooks-contrib
- PR #1015: Ability to control the number of internal streams in cumlHandle_impl via cumlHandle
+- PR #1175: Add more modules to docs ToC
## Bug Fixes
diff --git a/README.md b/README.md
index ec7c281282..15523518f9 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,7 @@ repo](https://github.com/rapidsai/notebooks-contrib).
| **Nonlinear Models for Regression or Classification** | Random Forest (RF) Classification | Experimental multi-node, multi-GPU version available via Dask integration |
| | Random Forest (RF) Regression | Experimental multi-node, multi-GPU version available via Dask integration |
| | K-Nearest Neighbors (KNN) | Multi-GPU <br> Uses [Faiss](https://github.com/facebookresearch/faiss) |
+| | Support Vector Machine Classifier (SVC) | |
| **Time Series** | Linear Kalman Filter | |
| | Holt-Winters Exponential Smoothing | |
---
diff --git a/cpp/src/randomforest/randomforest.cu b/cpp/src/randomforest/randomforest.cu
index 59c8ed9c5c..fa82017d4a 100644
--- a/cpp/src/randomforest/randomforest.cu
+++ b/cpp/src/randomforest/randomforest.cu
@@ -150,10 +150,11 @@ void postprocess_labels(int n_rows, std::vector<int>& labels,
 * @param[in] cfg_n_streams: No of parallel CUDA streams for training the forest
*/
void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
- float cfg_rows_sample, int cfg_n_streams) {
+ float cfg_rows_sample, int cfg_seed, int cfg_n_streams) {
params.n_trees = cfg_n_trees;
params.bootstrap = cfg_bootstrap;
params.rows_sample = cfg_rows_sample;
+ params.seed = cfg_seed;
params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (params.n_streams == cfg_n_streams) {
std::cout << "Warning! Max setting Max streams to max openmp threads "
@@ -173,11 +174,12 @@ void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
* @param[in] cfg_tree_params: tree parameters
*/
void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
- float cfg_rows_sample, int cfg_n_streams,
+ float cfg_rows_sample, int cfg_seed, int cfg_n_streams,
DecisionTree::DecisionTreeParams cfg_tree_params) {
params.n_trees = cfg_n_trees;
params.bootstrap = cfg_bootstrap;
params.rows_sample = cfg_rows_sample;
+ params.seed = cfg_seed;
params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (cfg_n_trees < params.n_streams) params.n_streams = cfg_n_trees;
set_tree_params(params.tree_params); // use input tree params
@@ -462,15 +464,16 @@ RF_metrics score(const cumlHandle& user_handle,
RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_rows_per_node,
bool bootstrap_features, bool bootstrap, int n_trees,
- float rows_sample, CRITERION split_criterion,
- bool quantile_per_tree, int cfg_n_streams) {
+ float rows_sample, int seed,
+ CRITERION split_criterion, bool quantile_per_tree,
+ int cfg_n_streams) {
DecisionTree::DecisionTreeParams tree_params;
DecisionTree::set_tree_params(
tree_params, max_depth, max_leaves, max_features, n_bins, split_algo,
min_rows_per_node, bootstrap_features, split_criterion, quantile_per_tree);
RF_params rf_params;
- set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, cfg_n_streams,
- tree_params);
+ set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, seed,
+ cfg_n_streams, tree_params);
return rf_params;
}
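
Reviewer note: a minimal usage sketch of the new `cfg_seed` argument, assuming the 0.10-era `ML` namespace API shown above; the values and the helper name are illustrative, not part of the patch.

```cpp
#include "randomforest/randomforest.hpp"

// Sketch: build reproducible RF params with the new seed argument.
// cfg_seed = -1 (the declared default) keeps the previous
// tree_id-only seeding behavior.
void example_rf_params() {
  ML::RF_params rf_params;
  ML::DecisionTree::DecisionTreeParams tree_params;
  ML::DecisionTree::set_tree_params(tree_params);  // library defaults
  ML::set_all_rf_params(rf_params, /*cfg_n_trees=*/100,
                        /*cfg_bootstrap=*/true, /*cfg_rows_sample=*/1.0f,
                        /*cfg_seed=*/42, /*cfg_n_streams=*/8, tree_params);
}
```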
diff --git a/cpp/src/randomforest/randomforest.hpp b/cpp/src/randomforest/randomforest.hpp
index 75f78e6340..277aee5172 100644
--- a/cpp/src/randomforest/randomforest.hpp
+++ b/cpp/src/randomforest/randomforest.hpp
@@ -65,6 +65,10 @@ struct RF_params {
/**
* Decision tree training hyper parameter struct.
*/
+ /**
+ * random seed
+ */
+ int seed;
/**
* Number of concurrent GPU streams for parallel tree building.
* Each stream is independently managed by CPU thread.
@@ -76,9 +80,9 @@ struct RF_params {
void set_rf_params(RF_params& params, int cfg_n_trees = 1,
bool cfg_bootstrap = true, float cfg_rows_sample = 1.0f,
- int cfg_n_streams = 8);
+ int cfg_seed = -1, int cfg_n_streams = 8);
void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
- float cfg_rows_sample, int cfg_n_streams,
+ float cfg_rows_sample, int cfg_seed, int cfg_n_streams,
DecisionTree::DecisionTreeParams cfg_tree_params);
void validity_check(const RF_params rf_params);
void print(const RF_params rf_params);
@@ -154,8 +158,9 @@ RF_metrics score(const cumlHandle& user_handle,
RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_rows_per_node,
bool bootstrap_features, bool bootstrap, int n_trees,
- float rows_sample, CRITERION split_criterion,
- bool quantile_per_tree, int cfg_n_streams);
+ float rows_sample, int seed,
+ CRITERION split_criterion, bool quantile_per_tree,
+ int cfg_n_streams);
// ----------------------------- Regression ----------------------------------- //
diff --git a/cpp/src/randomforest/randomforest_impl.cuh b/cpp/src/randomforest/randomforest_impl.cuh
index a1de148a10..471e6da754 100644
--- a/cpp/src/randomforest/randomforest_impl.cuh
+++ b/cpp/src/randomforest/randomforest_impl.cuh
@@ -72,7 +72,10 @@ void rf<T, L>::prepare_fit_per_tree(
   int tree_id, int n_rows, int n_sampled_rows, unsigned int* selected_rows,
   const int num_sms, const cudaStream_t stream,
   const std::shared_ptr<deviceAllocator> device_allocator) {
- srand(tree_id * 1000);
+ int rs = tree_id;
+ if (rf_params.seed > -1) rs = rf_params.seed + tree_id;
+
+ srand(rs * 1000);
if (rf_params.bootstrap) {
random_uniformInt(tree_id, selected_rows, n_sampled_rows, n_rows, num_sms,
stream);
@@ -221,10 +224,9 @@ void rfClassifier<T>::fit(const cumlHandle& user_handle, const T* input,
unsigned int* rowids;
rowids = selected_rows[stream_id]->data();
- this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids,
- tempmem[stream_id]->num_sms,
- tempmem[stream_id]->stream,
- handle.getDeviceAllocator());
+ this->prepare_fit_per_tree(
+ i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms,
+ tempmem[stream_id]->stream, handle.getDeviceAllocator());
/* Build individual tree in the forest.
- input is a pointer to orig data that have n_cols features and n_rows rows.
@@ -236,8 +238,7 @@ void rfClassifier<T>::fit(const cumlHandle& user_handle, const T* input,
*/
DecisionTree::TreeMetaDataNode* tree_ptr = &(forest->trees[i]);
tree_ptr->treeid = i;
- trees[i].fit(handle.getDeviceAllocator(),
- handle.getHostAllocator(),
+ trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(),
tempmem[stream_id]->stream, input, n_cols, n_rows, labels,
rowids, n_sampled_rows, n_unique_labels, tree_ptr,
this->rf_params.tree_params, tempmem[stream_id]);
@@ -485,10 +486,9 @@ void rfRegressor<T>::fit(const cumlHandle& user_handle, const T* input,
for (int i = 0; i < this->rf_params.n_trees; i++) {
int stream_id = omp_get_thread_num();
unsigned int* rowids = selected_rows[stream_id]->data();
- this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids,
- tempmem[stream_id]->num_sms,
- tempmem[stream_id]->stream,
- handle.getDeviceAllocator());
+ this->prepare_fit_per_tree(
+ i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms,
+ tempmem[stream_id]->stream, handle.getDeviceAllocator());
/* Build individual tree in the forest.
- input is a pointer to orig data that have n_cols features and n_rows rows.
@@ -499,8 +499,7 @@ void rfRegressor<T>::fit(const cumlHandle& user_handle, const T* input,
*/
DecisionTree::TreeMetaDataNode* tree_ptr = &(forest->trees[i]);
tree_ptr->treeid = i;
- trees[i].fit(handle.getDeviceAllocator(),
- handle.getHostAllocator(),
+ trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(),
tempmem[stream_id]->stream, input, n_cols, n_rows, labels,
rowids, n_sampled_rows, tree_ptr, this->rf_params.tree_params,
tempmem[stream_id]);
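
Reviewer note: the seeding rule in `prepare_fit_per_tree` above, restated as a standalone helper to make the reproducibility contract explicit (this helper is illustrative and not part of the patch):

```cpp
// With a user-supplied seed (>= 0), tree t is seeded with (seed + t) * 1000,
// so two forests trained with the same seed draw identical bootstrap samples;
// with the default seed of -1, the old srand(t * 1000) behavior is preserved.
int per_tree_srand_value(int user_seed, int tree_id) {
  int rs = tree_id;
  if (user_seed > -1) rs = user_seed + tree_id;
  return rs * 1000;
}
```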
diff --git a/cpp/src_prims/linalg/binary_op.h b/cpp/src_prims/linalg/binary_op.h
index 9f799e076c..a7ad986b7a 100644
--- a/cpp/src_prims/linalg/binary_op.h
+++ b/cpp/src_prims/linalg/binary_op.h
@@ -22,37 +22,41 @@
namespace MLCommon {
namespace LinAlg {
-template <typename math_t, int veclen_, typename Lambda, typename IdxType>
-__global__ void binaryOpKernel(math_t *out, const math_t *in1,
-                               const math_t *in2, IdxType len, Lambda op) {
-  typedef TxN_t<math_t, veclen_> VecType;
-  VecType a, b;
+template <typename InType, int VecLen, typename Lambda, typename OutType,
+          typename IdxType>
+__global__ void binaryOpKernel(OutType *out, const InType *in1,
+                               const InType *in2, IdxType len, Lambda op) {
+  typedef TxN_t<InType, VecLen> InVecType;
+  typedef TxN_t<OutType, VecLen> OutVecType;
+  InVecType a, b;
+  OutVecType c;
IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * blockDim.x);
- idx *= VecType::Ratio;
+ idx *= InVecType::Ratio;
if (idx >= len) return;
a.load(in1, idx);
b.load(in2, idx);
#pragma unroll
- for (int i = 0; i < VecType::Ratio; ++i) {
- a.val.data[i] = op(a.val.data[i], b.val.data[i]);
+ for (int i = 0; i < InVecType::Ratio; ++i) {
+ c.val.data[i] = op(a.val.data[i], b.val.data[i]);
}
- a.store(out, idx);
+ c.store(out, idx);
}
-template <typename math_t, int veclen_, typename Lambda, typename IdxType,
-          int TPB>
-void binaryOpImpl(math_t *out, const math_t *in1, const math_t *in2,
+template <typename InType, int VecLen, typename Lambda, typename OutType,
+          typename IdxType, int TPB>
+void binaryOpImpl(OutType *out, const InType *in1, const InType *in2,
                   IdxType len, Lambda op, cudaStream_t stream) {
-  const IdxType nblks = ceildiv(veclen_ ? len / veclen_ : len, (IdxType)TPB);
-  binaryOpKernel<math_t, veclen_, Lambda, IdxType>
+  const IdxType nblks = ceildiv(VecLen ? len / VecLen : len, (IdxType)TPB);
+  binaryOpKernel<InType, VecLen, Lambda, OutType, IdxType>
     <<<nblks, TPB, 0, stream>>>(out, in1, in2, len, op);
CUDA_CHECK(cudaPeekAtLastError());
}
/**
* @brief perform element-wise binary operation on the input arrays
- * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam InType input data-type
* @tparam Lambda the device-lambda performing the actual operation
+ * @tparam OutType output data-type
 * @tparam IdxType Integer type used for addressing
* @tparam TPB threads-per-block in the final kernel launched
* @param out the output array
@@ -61,30 +65,34 @@ void binaryOpImpl(math_t *out, const math_t *in1, const math_t *in2,
* @param len number of elements in the input array
* @param op the device-lambda
* @param stream cuda stream where to launch work
+ * @note Lambda must be a functor with the following signature:
+ * `OutType func(const InType& val1, const InType& val2);`
*/
-template <typename math_t, typename Lambda, typename IdxType = int,
-          int TPB = 256>
-void binaryOp(math_t *out, const math_t *in1, const math_t *in2, IdxType len,
+template <typename InType, typename Lambda, typename OutType = InType,
+          typename IdxType = int, int TPB = 256>
+void binaryOp(OutType *out, const InType *in1, const InType *in2, IdxType len,
Lambda op, cudaStream_t stream) {
-  size_t bytes = len * sizeof(math_t);
-  if (16 / sizeof(math_t) && bytes % 16 == 0) {
-    binaryOpImpl<math_t, 16 / sizeof(math_t), Lambda, IdxType, TPB>(
+  constexpr auto maxSize =
+    sizeof(InType) > sizeof(OutType) ? sizeof(InType) : sizeof(OutType);
+  size_t bytes = len * maxSize;
+  if (16 / maxSize && bytes % 16 == 0) {
+    binaryOpImpl<InType, 16 / maxSize, Lambda, OutType, IdxType, TPB>(
      out, in1, in2, len, op, stream);
-  } else if (8 / sizeof(math_t) && bytes % 8 == 0) {
-    binaryOpImpl<math_t, 8 / sizeof(math_t), Lambda, IdxType, TPB>(
+  } else if (8 / maxSize && bytes % 8 == 0) {
+    binaryOpImpl<InType, 8 / maxSize, Lambda, OutType, IdxType, TPB>(
      out, in1, in2, len, op, stream);
-  } else if (4 / sizeof(math_t) && bytes % 4 == 0) {
-    binaryOpImpl<math_t, 4 / sizeof(math_t), Lambda, IdxType, TPB>(
+  } else if (4 / maxSize && bytes % 4 == 0) {
+    binaryOpImpl<InType, 4 / maxSize, Lambda, OutType, IdxType, TPB>(
      out, in1, in2, len, op, stream);
-  } else if (2 / sizeof(math_t) && bytes % 2 == 0) {
-    binaryOpImpl<math_t, 2 / sizeof(math_t), Lambda, IdxType, TPB>(
+  } else if (2 / maxSize && bytes % 2 == 0) {
+    binaryOpImpl<InType, 2 / maxSize, Lambda, OutType, IdxType, TPB>(
      out, in1, in2, len, op, stream);
-  } else if (1 / sizeof(math_t)) {
-    binaryOpImpl<math_t, 1 / sizeof(math_t), Lambda, IdxType, TPB>(
+  } else if (1 / maxSize) {
+    binaryOpImpl<InType, 1 / maxSize, Lambda, OutType, IdxType, TPB>(
      out, in1, in2, len, op, stream);
  } else {
-    binaryOpImpl<math_t, 1, Lambda, IdxType, TPB>(out, in1, in2, len, op,
-                                                  stream);
+    binaryOpImpl<InType, 1, Lambda, OutType, IdxType, TPB>(out, in1, in2, len,
+                                                           op, stream);
}
}
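
Reviewer note: with distinct `InType`/`OutType`, a binary op can now widen per element. A minimal usage sketch against the reconstructed signature above, mirroring the float-in/double-out case exercised by the new `BinaryOpTestF_i32_D` test:

```cpp
#include "linalg/binary_op.h"

// Sketch: elementwise add of two float arrays, accumulated into double.
// InType (float) and OutType (double) are deduced from the arguments.
void add_as_double(double *out, const float *in1, const float *in2, int len,
                   cudaStream_t stream) {
  MLCommon::LinAlg::binaryOp(
    out, in1, in2, len,
    [] __device__(float a, float b) -> double {
      return static_cast<double>(a) + static_cast<double>(b);
    },
    stream);
}
```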
diff --git a/cpp/src_prims/linalg/unary_op.h b/cpp/src_prims/linalg/unary_op.h
index 24860e6fae..35b1fdd9cd 100644
--- a/cpp/src_prims/linalg/unary_op.h
+++ b/cpp/src_prims/linalg/unary_op.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -22,36 +22,40 @@
namespace MLCommon {
namespace LinAlg {
-template <typename math_t, int veclen_, typename Lambda, typename IdxType>
-__global__ void unaryOpKernel(math_t *out, const math_t *in, IdxType len,
+template <typename InType, int VecLen, typename Lambda, typename OutType,
+          typename IdxType>
+__global__ void unaryOpKernel(OutType *out, const InType *in, IdxType len,
                               Lambda op) {
-  typedef TxN_t<math_t, veclen_> VecType;
-  VecType a;
+  typedef TxN_t<InType, VecLen> InVecType;
+  typedef TxN_t<OutType, VecLen> OutVecType;
+  InVecType a;
+  OutVecType b;
IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * blockDim.x);
- idx *= VecType::Ratio;
+ idx *= InVecType::Ratio;
if (idx >= len) return;
a.load(in, idx);
#pragma unroll
- for (int i = 0; i < VecType::Ratio; ++i) {
- a.val.data[i] = op(a.val.data[i]);
+ for (int i = 0; i < InVecType::Ratio; ++i) {
+ b.val.data[i] = op(a.val.data[i]);
}
- a.store(out, idx);
+ b.store(out, idx);
}
-template <typename math_t, int veclen_, typename Lambda, typename IdxType,
-          int TPB>
-void unaryOpImpl(math_t *out, const math_t *in, IdxType len, Lambda op,
+template <typename InType, int VecLen, typename Lambda, typename OutType,
+          typename IdxType, int TPB>
+void unaryOpImpl(OutType *out, const InType *in, IdxType len, Lambda op,
                  cudaStream_t stream) {
-  const IdxType nblks = ceildiv(veclen_ ? len / veclen_ : len, (IdxType)TPB);
-  unaryOpKernel<math_t, veclen_, Lambda, IdxType>
+  const IdxType nblks = ceildiv(VecLen ? len / VecLen : len, (IdxType)TPB);
+  unaryOpKernel<InType, VecLen, Lambda, OutType, IdxType>
    <<<nblks, TPB, 0, stream>>>(out, in, len, op);
CUDA_CHECK(cudaPeekAtLastError());
}
/**
* @brief perform element-wise unary operation in the input array
- * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam InType input data-type
* @tparam Lambda the device-lambda performing the actual operation
+ * @tparam OutType output data-type
 * @tparam IdxType Integer type used for addressing
* @tparam TPB threads-per-block in the final kernel launched
* @param out the output array
@@ -59,36 +63,41 @@ void unaryOpImpl(math_t *out, const math_t *in, IdxType len, Lambda op,
* @param len number of elements in the input array
* @param op the device-lambda
* @param stream cuda stream where to launch work
+ * @note Lambda must be a functor with the following signature:
+ * `OutType func(const InType& val);`
*/
-template <typename math_t, typename Lambda, typename IdxType = int,
-          int TPB = 256>
-void unaryOp(math_t *out, const math_t *in, IdxType len, Lambda op,
+template <typename InType, typename Lambda, typename OutType = InType,
+          typename IdxType = int, int TPB = 256>
+void unaryOp(OutType *out, const InType *in, IdxType len, Lambda op,
cudaStream_t stream) {
if (len <= 0) return; //silently skip in case of 0 length input
-  size_t bytes = len * sizeof(math_t);
+  constexpr auto maxSize =
+    sizeof(InType) >= sizeof(OutType) ? sizeof(InType) : sizeof(OutType);
+  size_t bytes = len * maxSize;
   uint64_t inAddr = uint64_t(in);
   uint64_t outAddr = uint64_t(out);
-  if (16 / sizeof(math_t) && bytes % 16 == 0 && inAddr % 16 == 0 &&
+  if (16 / maxSize && bytes % 16 == 0 && inAddr % 16 == 0 &&
      outAddr % 16 == 0) {
-    unaryOpImpl<math_t, 16 / sizeof(math_t), Lambda, IdxType, TPB>(out, in, len,
-                                                                   op, stream);
-  } else if (8 / sizeof(math_t) && bytes % 8 == 0 && inAddr % 8 == 0 &&
+    unaryOpImpl<InType, 16 / maxSize, Lambda, OutType, IdxType, TPB>(
+      out, in, len, op, stream);
+  } else if (8 / maxSize && bytes % 8 == 0 && inAddr % 8 == 0 &&
             outAddr % 8 == 0) {
-    unaryOpImpl<math_t, 8 / sizeof(math_t), Lambda, IdxType, TPB>(out, in, len,
-                                                                  op, stream);
-  } else if (4 / sizeof(math_t) && bytes % 4 == 0 && inAddr % 4 == 0 &&
+    unaryOpImpl<InType, 8 / maxSize, Lambda, OutType, IdxType, TPB>(
+      out, in, len, op, stream);
+  } else if (4 / maxSize && bytes % 4 == 0 && inAddr % 4 == 0 &&
             outAddr % 4 == 0) {
-    unaryOpImpl<math_t, 4 / sizeof(math_t), Lambda, IdxType, TPB>(out, in, len,
-                                                                  op, stream);
-  } else if (2 / sizeof(math_t) && bytes % 2 == 0 && inAddr % 2 == 0 &&
+    unaryOpImpl<InType, 4 / maxSize, Lambda, OutType, IdxType, TPB>(
+      out, in, len, op, stream);
+  } else if (2 / maxSize && bytes % 2 == 0 && inAddr % 2 == 0 &&
             outAddr % 2 == 0) {
-    unaryOpImpl<math_t, 2 / sizeof(math_t), Lambda, IdxType, TPB>(out, in, len,
-                                                                  op, stream);
-  } else if (1 / sizeof(math_t)) {
-    unaryOpImpl<math_t, 1 / sizeof(math_t), Lambda, IdxType, TPB>(out, in, len,
-                                                                  op, stream);
+    unaryOpImpl<InType, 2 / maxSize, Lambda, OutType, IdxType, TPB>(
+      out, in, len, op, stream);
+  } else if (1 / maxSize) {
+    unaryOpImpl<InType, 1 / maxSize, Lambda, OutType, IdxType, TPB>(
+      out, in, len, op, stream);
  } else {
-    unaryOpImpl<math_t, 1, Lambda, IdxType, TPB>(out, in, len, op, stream);
+    unaryOpImpl<InType, 1, Lambda, OutType, IdxType, TPB>(out, in, len, op,
+                                                          stream);
}
}
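
Reviewer note: the matching `unaryOp` usage sketch, converting on the way out; the float-to-double path is what the new `UnaryOpTestF_i32_D` test covers:

```cpp
#include "linalg/unary_op.h"

// Sketch: scale a float array by 2 and write the result as double.
void scale_to_double(double *out, const float *in, int len,
                     cudaStream_t stream) {
  MLCommon::LinAlg::unaryOp(
    out, in, len, [] __device__(float x) -> double { return 2.0 * x; },
    stream);
}
```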
diff --git a/cpp/test/prims/binary_op.cu b/cpp/test/prims/binary_op.cu
index 917df565d6..c0bac8b438 100644
--- a/cpp/test/prims/binary_op.cu
+++ b/cpp/test/prims/binary_op.cu
@@ -26,19 +26,21 @@ namespace LinAlg {
// Or else, we get the following compilation error
// for an extended __device__ lambda cannot have private or protected access
// within its class
-template <typename T, typename IdxType = int>
-void binaryOpLaunch(T *out, const T *in1, const T *in2, IdxType len,
-                    cudaStream_t stream) {
+template <typename InType, typename IdxType, typename OutType = InType>
+void binaryOpLaunch(OutType *out, const InType *in1, const InType *in2,
+                    IdxType len, cudaStream_t stream) {
binaryOp(
- out, in1, in2, len, [] __device__(T a, T b) { return a + b; }, stream);
+ out, in1, in2, len, [] __device__(InType a, InType b) { return a + b; },
+ stream);
}
-template <typename T, typename IdxType>
+template <typename InType, typename IdxType, typename OutType = InType>
 class BinaryOpTest
-  : public ::testing::TestWithParam<BinaryOpInputs<T, IdxType>> {
+  : public ::testing::TestWithParam<BinaryOpInputs<InType, IdxType, OutType>> {
  protected:
  void SetUp() override {
-    params = ::testing::TestWithParam<BinaryOpInputs<T, IdxType>>::GetParam();
+    params = ::testing::TestWithParam<
+      BinaryOpInputs<InType, IdxType, OutType>>::GetParam();
Random::Rng r(params.seed);
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
@@ -47,8 +49,8 @@ class BinaryOpTest
allocate(in2, len);
allocate(out_ref, len);
allocate(out, len);
- r.uniform(in1, len, T(-1.0), T(1.0), stream);
- r.uniform(in2, len, T(-1.0), T(1.0), stream);
+ r.uniform(in1, len, InType(-1.0), InType(1.0), stream);
+ r.uniform(in2, len, InType(-1.0), InType(1.0), stream);
naiveAdd(out_ref, in1, in2, len);
binaryOpLaunch(out, in1, in2, len, stream);
CUDA_CHECK(cudaStreamDestroy(stream));
@@ -62,8 +64,9 @@ class BinaryOpTest
}
protected:
-  BinaryOpInputs<T, IdxType> params;
-  T *in1, *in2, *out_ref, *out;
+  BinaryOpInputs<InType, IdxType, OutType> params;
+  InType *in1, *in2;
+  OutType *out_ref, *out;
};
const std::vector<BinaryOpInputs<float, int>> inputsf_i32 = {
@@ -86,6 +89,16 @@ TEST_P(BinaryOpTestF_i64, Result) {
INSTANTIATE_TEST_CASE_P(BinaryOpTests, BinaryOpTestF_i64,
::testing::ValuesIn(inputsf_i64));
+const std::vector<BinaryOpInputs<float, int, double>> inputsf_i32_d = {
+  {0.000001f, 1024 * 1024, 1234ULL}};
+typedef BinaryOpTest<float, int, double> BinaryOpTestF_i32_D;
+TEST_P(BinaryOpTestF_i32_D, Result) {
+  ASSERT_TRUE(devArrMatch(out_ref, out, params.len,
+                          CompareApprox<double>(params.tolerance)));
+}
+INSTANTIATE_TEST_CASE_P(BinaryOpTests, BinaryOpTestF_i32_D,
+ ::testing::ValuesIn(inputsf_i32_d));
+
const std::vector<BinaryOpInputs<double, int>> inputsd_i32 = {
  {0.00000001, 1024 * 1024, 1234ULL}};
typedef BinaryOpTest<double, int> BinaryOpTestD_i32;
diff --git a/cpp/test/prims/binary_op.h b/cpp/test/prims/binary_op.h
index 30d0b3f42b..b1d25d81e2 100644
--- a/cpp/test/prims/binary_op.h
+++ b/cpp/test/prims/binary_op.h
@@ -22,33 +22,33 @@
namespace MLCommon {
namespace LinAlg {
-template <typename Type, typename IdxType = int>
-__global__ void naiveAddKernel(Type *out, const Type *in1, const Type *in2,
-                               IdxType len) {
+template <typename OutType, typename InType, typename IdxType = int>
+__global__ void naiveAddKernel(OutType *out, const InType *in1,
+                               const InType *in2, IdxType len) {
   IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * (IdxType)blockDim.x);
   if (idx < len) {
-    out[idx] = in1[idx] + in2[idx];
+    out[idx] = static_cast<OutType>(in1[idx] + in2[idx]);
}
}
-template <typename Type, typename IdxType = int>
-void naiveAdd(Type *out, const Type *in1, const Type *in2, IdxType len) {
+template <typename OutType, typename InType, typename IdxType = int>
+void naiveAdd(OutType *out, const InType *in1, const InType *in2, IdxType len) {
   static const IdxType TPB = 64;
   IdxType nblks = ceildiv(len, TPB);
-  naiveAddKernel<Type, IdxType><<<nblks, TPB>>>(out, in1, in2, len);
+  naiveAddKernel<OutType, InType, IdxType><<<nblks, TPB>>>(out, in1, in2, len);
CUDA_CHECK(cudaPeekAtLastError());
}
-template <typename T, typename IdxType = int>
+template <typename InType, typename IdxType = int, typename OutType = InType>
 struct BinaryOpInputs {
-  T tolerance;
+  InType tolerance;
   IdxType len;
   unsigned long long int seed;
 };
-template <typename T, typename IdxType = int>
+template <typename InType, typename IdxType = int, typename OutType = InType>
 ::std::ostream &operator<<(::std::ostream &os,
-                           const BinaryOpInputs<T, IdxType> &dims) {
+                           const BinaryOpInputs<InType, IdxType, OutType> &d) {
return os;
}
diff --git a/cpp/test/prims/unary_op.cu b/cpp/test/prims/unary_op.cu
index 62d925fbc7..ca9d0d25c9 100644
--- a/cpp/test/prims/unary_op.cu
+++ b/cpp/test/prims/unary_op.cu
@@ -26,29 +26,31 @@ namespace LinAlg {
// Or else, we get the following compilation error
// for an extended __device__ lambda cannot have private or protected access
// within its class
-template <typename T, typename IdxType = int>
-void unaryOpLaunch(T *out, const T *in, T scalar, IdxType len,
+template <typename InType, typename IdxType, typename OutType = InType>
+void unaryOpLaunch(OutType *out, const InType *in, InType scalar, IdxType len,
                    cudaStream_t stream) {
-  unaryOp(
-    out, in, len, [scalar] __device__(T in) { return in * scalar; }, stream);
+  auto op = [scalar] __device__(InType in) {
+    return static_cast<OutType>(in * scalar);
+  };
+  unaryOp(out, in, len, op, stream);
}
-template <typename T, typename IdxType>
-class UnaryOpTest : public ::testing::TestWithParam<UnaryOpInputs<T, IdxType>> {
+template <typename InType, typename IdxType, typename OutType = InType>
+class UnaryOpTest
+  : public ::testing::TestWithParam<UnaryOpInputs<InType, IdxType, OutType>> {
  protected:
  void SetUp() override {
-    params = ::testing::TestWithParam<UnaryOpInputs<T, IdxType>>::GetParam();
+    params = ::testing::TestWithParam<
+      UnaryOpInputs<InType, IdxType, OutType>>::GetParam();
Random::Rng r(params.seed);
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
-
auto len = params.len;
auto scalar = params.scalar;
-
allocate(in, len);
allocate(out_ref, len);
allocate(out, len);
- r.uniform(in, len, T(-1.0), T(1.0), stream);
+ r.uniform(in, len, InType(-1.0), InType(1.0), stream);
naiveScale(out_ref, in, scalar, len, stream);
unaryOpLaunch(out, in, scalar, len, stream);
CUDA_CHECK(cudaStreamDestroy(stream));
@@ -61,8 +63,9 @@ class UnaryOpTest : public ::testing::TestWithParam<UnaryOpInputs<T, IdxType>> {
}
protected:
-  UnaryOpInputs<T, IdxType> params;
-  T *in, *out_ref, *out;
+  UnaryOpInputs<InType, IdxType, OutType> params;
+  InType *in;
+  OutType *out_ref, *out;
};
const std::vector<UnaryOpInputs<float, int>> inputsf_i32 = {
@@ -85,6 +88,16 @@ TEST_P(UnaryOpTestF_i64, Result) {
INSTANTIATE_TEST_CASE_P(UnaryOpTests, UnaryOpTestF_i64,
::testing::ValuesIn(inputsf_i64));
+const std::vector<UnaryOpInputs<float, int, double>> inputsf_i32_d = {
+  {0.000001f, 1024 * 1024, 2.f, 1234ULL}};
+typedef UnaryOpTest<float, int, double> UnaryOpTestF_i32_D;
+TEST_P(UnaryOpTestF_i32_D, Result) {
+  ASSERT_TRUE(devArrMatch(out_ref, out, params.len,
+                          CompareApprox<double>(params.tolerance)));
+}
+INSTANTIATE_TEST_CASE_P(UnaryOpTests, UnaryOpTestF_i32_D,
+ ::testing::ValuesIn(inputsf_i32_d));
+
const std::vector<UnaryOpInputs<double, int>> inputsd_i32 = {
  {0.00000001, 1024 * 1024, 2.0, 1234ULL}};
typedef UnaryOpTest<double, int> UnaryOpTestD_i32;
diff --git a/cpp/test/prims/unary_op.h b/cpp/test/prims/unary_op.h
index bc08a56d59..24e5e2ec1c 100644
--- a/cpp/test/prims/unary_op.h
+++ b/cpp/test/prims/unary_op.h
@@ -22,35 +22,36 @@
namespace MLCommon {
namespace LinAlg {
-template <typename Type, typename IdxType = int>
-__global__ void naiveScaleKernel(Type *out, const Type *in, Type scalar,
+template <typename OutType, typename InType, typename IdxType = int>
+__global__ void naiveScaleKernel(OutType *out, const InType *in, InType scalar,
                                  IdxType len) {
   IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * (IdxType)blockDim.x);
   if (idx < len) {
-    out[idx] = scalar * in[idx];
+    out[idx] = static_cast<OutType>(scalar * in[idx]);
}
}
-template <typename Type, typename IdxType = int>
-void naiveScale(Type *out, const Type *in, Type scalar, int len,
+template <typename OutType, typename InType, typename IdxType = int>
+void naiveScale(OutType *out, const InType *in, InType scalar, int len,
                 cudaStream_t stream) {
   static const int TPB = 64;
   int nblks = ceildiv(len, TPB);
-  naiveScaleKernel<Type><<<nblks, TPB, 0, stream>>>(out, in, scalar, len);
+  naiveScaleKernel<OutType, InType>
+    <<<nblks, TPB, 0, stream>>>(out, in, scalar, len);
CUDA_CHECK(cudaPeekAtLastError());
}
-template <typename T, typename IdxType = int>
+template <typename InType, typename IdxType = int, typename OutType = InType>
struct UnaryOpInputs {
- T tolerance;
+ InType tolerance;
IdxType len;
- T scalar;
+ InType scalar;
unsigned long long int seed;
};
-template <typename T, typename IdxType = int>
+template <typename InType, typename IdxType = int, typename OutType = InType>
 ::std::ostream &operator<<(::std::ostream &os,
-                           const UnaryOpInputs<T, IdxType> &dims) {
+                           const UnaryOpInputs<InType, IdxType, OutType> &d) {
return os;
}
diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu
index 512c53c2af..115cbd8e97 100644
--- a/cpp/test/sg/rf_test.cu
+++ b/cpp/test/sg/rf_test.cu
@@ -61,7 +61,7 @@ class RfClassifierTest : public ::testing::TestWithParam<RfInputs<T>> {
params.split_criterion, false);
RF_params rf_params;
set_all_rf_params(rf_params, params.n_trees, params.bootstrap,
- params.rows_sample, params.n_streams, tree_params);
+ params.rows_sample, -1, params.n_streams, tree_params);
//print(rf_params);
//--------------------------------------------------------
@@ -161,7 +161,7 @@ class RfRegressorTest : public ::testing::TestWithParam<RfInputs<T>> {
params.split_criterion, false);
RF_params rf_params;
set_all_rf_params(rf_params, params.n_trees, params.bootstrap,
- params.rows_sample, params.n_streams, tree_params);
+ params.rows_sample, -1, params.n_streams, tree_params);
//print(rf_params);
//--------------------------------------------------------
diff --git a/cpp/test/sg/rf_treelite_test.cu b/cpp/test/sg/rf_treelite_test.cu
index d772b0c83e..279db9588e 100644
--- a/cpp/test/sg/rf_treelite_test.cu
+++ b/cpp/test/sg/rf_treelite_test.cu
@@ -181,7 +181,7 @@ class RfTreeliteTestCommon : public ::testing::TestWithParam<RfInputs<T>> {
params.min_rows_per_node, params.bootstrap_features,
params.split_criterion, false);
set_all_rf_params(rf_params, params.n_trees, params.bootstrap,
- params.rows_sample, params.n_streams, tree_params);
+ params.rows_sample, -1, params.n_streams, tree_params);
// print(rf_params);
handle.reset(new cumlHandle(rf_params.n_streams));
diff --git a/docs/source/api.rst b/docs/source/api.rst
index fce956891b..ce0fb8d750 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -4,8 +4,8 @@ cuML API Reference
-Preprocessing
-==============
+Preprocessing, Metrics, and Utilities
+=====================================
Model Selection and Data Splitting
----------------------------------
@@ -24,6 +24,44 @@ Dataset Generation
.. automethod:: cuml.datasets.make_blobs
+Metrics
+---------
+
+ .. automodule:: cuml.metrics.regression
+ :members:
+
+ .. automodule:: cuml.metrics.accuracy
+ :members:
+
+ .. automodule:: cuml.metrics.trustworthiness
+ :members:
+
+ .. automodule:: cuml.metrics.cluster
+ :members:
+
+Benchmarking
+-------------
+
+ .. automodule:: cuml.benchmark.algorithms
+ :members:
+
+ .. automodule:: cuml.benchmark.runners
+ :members:
+
+ .. automodule:: cuml.benchmark.datagen
+ :members:
+
+
+
+Utilities for I/O and Numba
+---------------------------
+
+ .. automodule:: cuml.utils.input_utils
+ :members:
+
+ .. automodule:: cuml.utils.numba_utils
+ :members:
+
Regression and Classification
=============================
@@ -84,6 +122,12 @@ Quasi-Newton
.. autoclass:: cuml.QN
:members:
+Support Vector Machines
+------------------------
+
+.. autoclass:: cuml.svm.SVC
+ :members:
+
Clustering
==========
diff --git a/python/cuml/benchmark/datagen.py b/python/cuml/benchmark/datagen.py
index 8a881687ad..203d6a6759 100644
--- a/python/cuml/benchmark/datagen.py
+++ b/python/cuml/benchmark/datagen.py
@@ -22,7 +22,7 @@
* n_samples (set to 0 for 'default')
* n_features (set to 0 for 'default')
* random_state
- * .. and optional generator-specific parameters
+ * (and optional generator-specific parameters)
The function should return a 2-tuple (X, y), where X is a Pandas
dataframe and y is a Pandas series. If the generator does not produce
diff --git a/python/cuml/cluster/kmeans.pyx b/python/cuml/cluster/kmeans.pyx
index 802acdc907..aa1cd53fab 100644
--- a/python/cuml/cluster/kmeans.pyx
+++ b/python/cuml/cluster/kmeans.pyx
@@ -24,7 +24,7 @@ import cudf
import numpy as np
import warnings
-from librmm_cffi import librmm as rmm
+import rmm
from libcpp cimport bool
from libc.stdint cimport uintptr_t
diff --git a/python/cuml/cluster/kmeans_mg.pyx b/python/cuml/cluster/kmeans_mg.pyx
index d04d448f7a..57785a4a5d 100644
--- a/python/cuml/cluster/kmeans_mg.pyx
+++ b/python/cuml/cluster/kmeans_mg.pyx
@@ -24,7 +24,7 @@ import cudf
import numpy as np
import warnings
-from librmm_cffi import librmm as rmm
+import rmm
from libcpp cimport bool
from libc.stdint cimport uintptr_t
diff --git a/python/cuml/dask/ensemble/randomforestclassifier.py b/python/cuml/dask/ensemble/randomforestclassifier.py
index 9e063e18a4..8187af137a 100755
--- a/python/cuml/dask/ensemble/randomforestclassifier.py
+++ b/python/cuml/dask/ensemble/randomforestclassifier.py
@@ -194,6 +194,12 @@ def __init__(
self.n_estimators_per_worker[i] + 1
)
+ seeds = list()
+ seeds.append(0)
+ for i in range(1, len(self.n_estimators_per_worker)):
+ sd = self.n_estimators_per_worker[i-1] + seeds[i-1]
+ seeds.append(sd)
+
key = str(uuid1())
self.rfs = {
worker: c.submit(
@@ -213,6 +219,7 @@ def __init__(
rows_sample,
max_leaves,
quantile_per_tree,
+ seeds[n],
dtype,
key="%s-%s" % (key, n),
workers=[worker],
@@ -243,6 +250,7 @@ def _func_build_rf(
rows_sample,
max_leaves,
quantile_per_tree,
+ seed,
dtype,
):
return cuRFC(
@@ -262,6 +270,7 @@ def _func_build_rf(
max_leaves=max_leaves,
n_streams=n_streams,
quantile_per_tree=quantile_per_tree,
+ seed=seed,
gdf_datatype=dtype,
)
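
Reviewer note: the `seeds` loop above (duplicated in the regressor below) is a prefix sum over per-worker estimator counts, so per-tree seeds `seed + tree_id` never collide across workers. Restated here for clarity; the helper name is illustrative:

```cpp
#include <cstddef>
#include <vector>

// Worker 0 starts at seed 0; worker i starts where worker i-1's trees end,
// mirroring the Python loop in the Dask RF wrappers above.
std::vector<int> worker_seed_offsets(const std::vector<int> &trees_per_worker) {
  std::vector<int> seeds = {0};
  for (std::size_t i = 1; i < trees_per_worker.size(); ++i)
    seeds.push_back(seeds[i - 1] + trees_per_worker[i - 1]);
  return seeds;
}
```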
diff --git a/python/cuml/dask/ensemble/randomforestregressor.py b/python/cuml/dask/ensemble/randomforestregressor.py
index b86e1a9269..e12bb75056 100755
--- a/python/cuml/dask/ensemble/randomforestregressor.py
+++ b/python/cuml/dask/ensemble/randomforestregressor.py
@@ -197,6 +197,12 @@ def __init__(
self.n_estimators_per_worker[i] + 1
)
+ seeds = list()
+ seeds.append(0)
+ for i in range(1, len(self.n_estimators_per_worker)):
+ sd = self.n_estimators_per_worker[i-1] + seeds[i-1]
+ seeds.append(sd)
+
key = str(uuid1())
self.rfs = {
worker: c.submit(
@@ -216,6 +222,7 @@ def __init__(
max_leaves,
accuracy_metric,
quantile_per_tree,
+ seeds[n],
key="%s-%s" % (key, n),
workers=[worker],
)
@@ -245,6 +252,7 @@ def _func_build_rf(
max_leaves,
accuracy_metric,
quantile_per_tree,
+ seed,
):
return cuRFR(
@@ -264,6 +272,7 @@ def _func_build_rf(
n_streams=n_streams,
accuracy_metric=accuracy_metric,
quantile_per_tree=quantile_per_tree,
+ seed=seed,
)
@staticmethod
diff --git a/python/cuml/dask/linear_model/linear_regression.py b/python/cuml/dask/linear_model/linear_regression.py
index 1034af5489..a860c00b6f 100644
--- a/python/cuml/dask/linear_model/linear_regression.py
+++ b/python/cuml/dask/linear_model/linear_regression.py
@@ -26,7 +26,7 @@
from dask import delayed
from dask.distributed import wait, default_client
from math import ceil
-from librmm_cffi import librmm as rmm
+import rmm
from toolz import first
from tornado import gen
diff --git a/python/cuml/dask/neighbors/nearest_neighbors.py b/python/cuml/dask/neighbors/nearest_neighbors.py
index b37ee0568c..7128eb2dd1 100644
--- a/python/cuml/dask/neighbors/nearest_neighbors.py
+++ b/python/cuml/dask/neighbors/nearest_neighbors.py
@@ -24,7 +24,7 @@
import random
from cuml.utils import numba_utils
-from librmm_cffi import librmm as rmm
+import rmm
from dask import delayed
from collections import defaultdict
diff --git a/python/cuml/decomposition/pca.pyx b/python/cuml/decomposition/pca.pyx
index 78aab743ad..cd28e929ed 100644
--- a/python/cuml/decomposition/pca.pyx
+++ b/python/cuml/decomposition/pca.pyx
@@ -23,7 +23,7 @@ import ctypes
import cudf
import numpy as np
-from librmm_cffi import librmm as rmm
+import rmm
from libcpp cimport bool
from libc.stdint cimport uintptr_t
diff --git a/python/cuml/decomposition/tsvd.pyx b/python/cuml/decomposition/tsvd.pyx
index eb10bf9565..66e4866048 100644
--- a/python/cuml/decomposition/tsvd.pyx
+++ b/python/cuml/decomposition/tsvd.pyx
@@ -23,7 +23,7 @@ import ctypes
import cudf
import numpy as np
-from librmm_cffi import librmm as rmm
+import rmm
from libcpp cimport bool
from libc.stdint cimport uintptr_t
diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx
index 04f848a92d..4108a37b8d 100644
--- a/python/cuml/ensemble/randomforestclassifier.pyx
+++ b/python/cuml/ensemble/randomforestclassifier.pyx
@@ -81,6 +81,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML":
int n_trees
bool bootstrap
float rows_sample
+ int seed
pass
cdef cppclass RandomForestMetaData[T, L]:
@@ -181,6 +182,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML":
bool,
int,
float,
+ int,
CRITERION,
bool,
int) except +
@@ -302,7 +304,8 @@ class RandomForestClassifier(Base):
min_samples_leaf=None, min_weight_fraction_leaf=None,
max_leaf_nodes=None, min_impurity_decrease=None,
min_impurity_split=None, oob_score=None, n_jobs=None,
- random_state=None, warm_start=None, class_weight=None):
+ random_state=None, warm_start=None, class_weight=None,
+ seed=-1):
sklearn_params = {"criterion": criterion,
"min_samples_leaf": min_samples_leaf,
@@ -353,6 +356,7 @@ class RandomForestClassifier(Base):
self.quantile_per_tree = quantile_per_tree
self.n_cols = None
self.n_streams = n_streams
+ self.seed = seed
cdef RandomForestMetaData[float, int] *rf_forest = \
new RandomForestMetaData[float, int]()
@@ -497,6 +501,7 @@ class RandomForestClassifier(Base):
self.bootstrap,
self.n_estimators,
self.rows_sample,
+ self.seed,
self.split_criterion,
self.quantile_per_tree,
self.n_streams)
@@ -608,6 +613,7 @@ class RandomForestClassifier(Base):
num_classes=2):
"""
Predicts the labels for X.
+
Parameters
----------
X : array-like (device or host) shape = (n_samples, n_features)
@@ -639,9 +645,10 @@ class RandomForestClassifier(Base):
It is applied if output_class == True, else it is ignored
num_classes : integer
number of different classes present in the dataset
+
Returns
----------
- y: NumPy
+ y : NumPy
Dense vector (int) of shape (n_samples, 1)
"""
if self.dtype == np.float64:
@@ -662,15 +669,17 @@ class RandomForestClassifier(Base):
def _predict_get_all(self, X):
"""
Predicts the labels for X.
+
Parameters
----------
X : array-like (device or host) shape = (n_samples, n_features)
Dense matrix (floats or doubles) of shape (n_samples, n_features).
Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
ndarray, cuda array interface compliant array like CuPy
+
Returns
----------
- y: NumPy
+ y : NumPy
Dense vector (int) of shape (n_samples, 1)
"""
cdef uintptr_t X_ptr
@@ -724,17 +733,20 @@ class RandomForestClassifier(Base):
def score(self, X, y):
"""
Calculates the accuracy metric score of the model for X.
+
Parameters
----------
X : array-like (device or host) shape = (n_samples, n_features)
Dense matrix (floats or doubles) of shape (n_samples, n_features).
Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
ndarray, cuda array interface compliant array like CuPy
- y: NumPy
+ y : NumPy
Dense vector (int) of shape (n_samples, 1)
+
Returns
- ----------
- accuracy of the model
+ -------
+ float
+ Accuracy of the model [0.0 - 1.0]
"""
cdef uintptr_t X_ptr, y_ptr
X_m, X_ptr, n_rows, n_cols, _ = \
@@ -795,6 +807,7 @@ class RandomForestClassifier(Base):
"""
Returns the value of all parameters
required to configure this estimator as a dictionary.
+
Parameters
-----------
deep : boolean (default = True)
@@ -811,6 +824,7 @@ class RandomForestClassifier(Base):
Sets the value of parameters required to
configure this estimator, it functions similar to
the sklearn set_params.
+
Parameters
-----------
params : dict of new params
diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx
index 0b9009b510..3c675c74ef 100644
--- a/python/cuml/ensemble/randomforestregressor.pyx
+++ b/python/cuml/ensemble/randomforestregressor.pyx
@@ -80,6 +80,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML":
int n_trees
bool bootstrap
float rows_sample
+ int seed
pass
cdef cppclass RandomForestMetaData[T, L]:
@@ -162,6 +163,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML":
bool,
int,
float,
+ int,
CRITERION,
bool,
int) except +
@@ -286,7 +288,7 @@ class RandomForestRegressor(Base):
max_leaf_nodes=None, min_impurity_decrease=None,
min_impurity_split=None, oob_score=None,
random_state=None, warm_start=None, class_weight=None,
- quantile_per_tree=False, criterion=None):
+ quantile_per_tree=False, criterion=None, seed=-1):
sklearn_params = {"criterion": criterion,
"min_samples_leaf": min_samples_leaf,
@@ -337,6 +339,7 @@ class RandomForestRegressor(Base):
self.accuracy_metric = accuracy_metric
self.quantile_per_tree = quantile_per_tree
self.n_streams = n_streams
+ self.seed = seed
cdef RandomForestMetaData[float, float] *rf_forest = \
new RandomForestMetaData[float, float]()
@@ -461,6 +464,7 @@ class RandomForestRegressor(Base):
self.bootstrap,
self.n_estimators,
self.rows_sample,
+ self.seed,
self.split_criterion,
self.quantile_per_tree,
self.n_streams)
diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx
index 58ecb9645a..d55983abda 100644
--- a/python/cuml/fil/fil.pyx
+++ b/python/cuml/fil/fil.pyx
@@ -26,7 +26,7 @@ import math
import numpy as np
import warnings
-from librmm_cffi import librmm as rmm
+import rmm
from libcpp cimport bool
from libc.stdint cimport uintptr_t
diff --git a/python/cuml/filter/kalman_filter.pyx b/python/cuml/filter/kalman_filter.pyx
index 7308fa14ab..22bf510a33 100644
--- a/python/cuml/filter/kalman_filter.pyx
+++ b/python/cuml/filter/kalman_filter.pyx
@@ -24,7 +24,7 @@ import numpy as np
from numba import cuda
from cuml.utils import numba_utils
-from librmm_cffi import librmm as rmm
+import rmm
from libc.stdint cimport uintptr_t
from libc.stdlib cimport calloc, malloc, free
diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx
index 9ee124b97f..cb27bbf9af 100644
--- a/python/cuml/manifold/t_sne.pyx
+++ b/python/cuml/manifold/t_sne.pyx
@@ -32,7 +32,7 @@ from cuml.common.base import Base
from cuml.common.handle cimport cumlHandle
from cuml.utils import input_to_dev_array as to_cuda
-from librmm_cffi import librmm as rmm
+import rmm
from libcpp cimport bool
from libc.stdint cimport uintptr_t
diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx
index 1f120c0685..25a42669a5 100644
--- a/python/cuml/manifold/umap.pyx
+++ b/python/cuml/manifold/umap.pyx
@@ -31,7 +31,7 @@ from cuml.common.handle cimport cumlHandle
from cuml.utils import get_cudf_column_ptr, get_dev_array_ptr, \
input_to_dev_array, zeros, row_matrix
-from librmm_cffi import librmm as rmm
+import rmm
from libcpp cimport bool
from libc.stdint cimport uintptr_t
diff --git a/python/cuml/metrics/accuracy.pyx b/python/cuml/metrics/accuracy.pyx
index dfebe1910f..5ce62bb591 100644
--- a/python/cuml/metrics/accuracy.pyx
+++ b/python/cuml/metrics/accuracy.pyx
@@ -42,13 +42,16 @@ def accuracy_score(ground_truth, predictions, handle=None):
Parameters
----------
- handle : cuml.Handle
- prediction : The lables predicted by the model
- for the test dataset
- ground_truth : The ground truth labels of the test dataset
+    handle : cuml.Handle
+    prediction : NumPy ndarray or Numba device ndarray
+        The labels predicted by the model for the test dataset
+    ground_truth : NumPy ndarray or Numba device ndarray
+        The ground truth labels of the test dataset
+
Returns
-------
- The accuracy of the model used for prediction
+ float
+ The accuracy of the model used for prediction
"""
handle = cuml.common.handle.Handle() \
if handle is None else handle
diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx
index 7e21a1aa87..49812e76f4 100644
--- a/python/cuml/neighbors/nearest_neighbors.pyx
+++ b/python/cuml/neighbors/nearest_neighbors.pyx
@@ -35,14 +35,14 @@ from cython.operator cimport dereference as deref
from libcpp cimport bool
from libcpp.memory cimport shared_ptr
-from librmm_cffi import librmm as rmm
+import rmm
from libc.stdlib cimport malloc, free
from libc.stdint cimport uintptr_t
from libc.stdlib cimport calloc, malloc, free
from numba import cuda
-from librmm_cffi import librmm as rmm
+import rmm
cimport cuml.common.handle
cimport cuml.common.cuda
diff --git a/python/cuml/preprocessing/LabelEncoder.py b/python/cuml/preprocessing/LabelEncoder.py
index a62765b4c5..f5f28a0f27 100644
--- a/python/cuml/preprocessing/LabelEncoder.py
+++ b/python/cuml/preprocessing/LabelEncoder.py
@@ -16,7 +16,7 @@
import cudf
import nvcategory
-from librmm_cffi import librmm
+import rmm
import numpy as np
@@ -194,7 +194,7 @@ def fit_transform(self, y: cudf.Series) -> cudf.Series:
self._cats = nvcategory.from_strings(y.data)
self._fitted = True
- arr: librmm.device_array = librmm.device_array(
+ arr: rmm.device_array = rmm.device_array(
y.data.size(), dtype=np.int32
)
self._cats.values(devptr=arr.device_ctypes_pointer.value)
diff --git a/python/cuml/random_projection/random_projection.pyx b/python/cuml/random_projection/random_projection.pyx
index 7897eecc61..efa6ba8193 100644
--- a/python/cuml/random_projection/random_projection.pyx
+++ b/python/cuml/random_projection/random_projection.pyx
@@ -22,7 +22,7 @@
import cudf
import numpy as np
-from librmm_cffi import librmm as rmm
+import rmm
from libc.stdint cimport uintptr_t
from libcpp cimport bool
diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx
index f783590dfe..393ab2b379 100644
--- a/python/cuml/solvers/qn.pyx
+++ b/python/cuml/solvers/qn.pyx
@@ -23,7 +23,7 @@ import cudf
import numpy as np
import warnings
-from librmm_cffi import librmm as rmm
+import rmm
from libcpp cimport bool
from libc.stdint cimport uintptr_t
diff --git a/python/cuml/svm/svm.pyx b/python/cuml/svm/svm.pyx
index 141e8a5002..8ebf541f49 100644
--- a/python/cuml/svm/svm.pyx
+++ b/python/cuml/svm/svm.pyx
@@ -442,9 +442,10 @@ class SVC(Base):
Dense matrix (floats or doubles) of shape (n_samples, n_features).
Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
ndarray, cuda array interface compliant array like CuPy
+
Returns
- ----------
- y: cuDF Series
+ -------
+ y : cuDF Series
Dense vector (floats or doubles) of shape (n_samples, 1)
"""
diff --git a/python/cuml/utils/input_utils.py b/python/cuml/utils/input_utils.py
index 6170ab4223..62987d7bb8 100644
--- a/python/cuml/utils/input_utils.py
+++ b/python/cuml/utils/input_utils.py
@@ -26,7 +26,7 @@
from collections.abc import Collection
from numba import cuda
-from librmm_cffi import librmm as rmm
+import rmm
inp_array = namedtuple('inp_array', 'array pointer n_rows n_cols dtype')
@@ -71,14 +71,20 @@ def input_to_dev_array(X, order='F', deepcopy=False,
check_cols=False, check_rows=False,
fail_on_order=False):
"""
- Convert input X to device array suitable for C++ methods
+ Convert input X to device array suitable for C++ methods.
+
Acceptable input formats:
+
* cuDF Dataframe - returns a deep copy always
+
* cuDF Series - returns by reference or a deep copy depending on
`deepcopy`
+
* Numpy array - returns a copy in device always
+
* cuda array interface compliant array (like Cupy) - returns a
- reference unless deepcopy=True
+ reference unless `deepcopy`=True
+
* numba device array - returns a reference unless deepcopy=True
Parameters
@@ -309,13 +315,18 @@ def input_to_host_array(X, order='F', deepcopy=False,
"""
Convert input X to host array (NumPy) suitable for C++ methods that accept
host arrays.
+
Acceptable input formats:
+
* Numpy array - returns a pointer to the original input
+
* cuDF Dataframe - returns a deep copy always
- * cuDF Series - returns by reference or a deep copy depending on
- `deepcopy`
- * cuda array interface compliant array (like Cupy) - returns a
+
+ * cuDF Series - returns by reference or a deep copy depending on `deepcopy`
+
+ * cuda array interface compliant array (like Cupy) - returns a \
reference unless deepcopy=True
+
* numba device array - returns a reference unless deepcopy=True
Parameters
diff --git a/python/cuml/utils/numba_utils.py b/python/cuml/utils/numba_utils.py
index 288063b165..dc6b2d8c00 100644
--- a/python/cuml/utils/numba_utils.py
+++ b/python/cuml/utils/numba_utils.py
@@ -18,7 +18,7 @@
from numba import cuda
from numba.cuda.cudadrv.driver import driver
-from librmm_cffi import librmm as rmm
+import rmm
import numpy as np