diff --git a/CHANGELOG.md b/CHANGELOG.md index 51c3016454..a68c04ece0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,10 +22,12 @@ - PR #1086: Ensure RegressorMixin scorer uses device arrays - PR #1108: input_to_host_array function in input_utils for input processing to host arrays - PR #1114: K-means: Exposing useful params, removing unused params, proxying params in Dask +- PR #1142: prims: expose separate InType and OutType for unaryOp and binaryOp - PR #1115: Moving dask_make_blobs to cuml.dask.datasets. Adding conversion to dask.DataFrame - PR #1136: CUDA 10.1 CI updates - PR #1165: Adding except + in all remaining cython - PR #1173: Docs: Barnes Hut TSNE documentation +- PR #1176: Use new RMM API based on Cython ## Bug Fixes @@ -44,6 +46,7 @@ - PR #1106: Pinning Distributed version to match Dask for consistent CI results - PR #1116: TSNE CUDA 10.1 Bug Fixes - PR #1132: DBSCAN Batching Bug Fix +- PR #1162: DASK RF random seed bug fix - PR #1164: Fix check_dtype arg handling for input_to_dev_array # cuML 0.9.0 (21 Aug 2019) @@ -105,6 +108,7 @@ - PR #978: Update README for 0.9 - PR #1009: Fix references to notebooks-contrib - PR #1015: Ability to control the number of internal streams in cumlHandle_impl via cumlHandle +- PR #1175: Add more modules to docs ToC ## Bug Fixes diff --git a/README.md b/README.md index ec7c281282..15523518f9 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ repo](https://github.com/rapidsai/notebooks-contrib). | **Nonlinear Models for Regression or Classification** | Random Forest (RF) Classification | Experimental multi-node, multi-GPU version available via Dask integration | | | Random Forest (RF) Regression | Experimental multi-node, multi-GPU version available via Dask integration | | | K-Nearest Neighbors (KNN) | Multi-GPU
Uses [Faiss](https://github.com/facebookresearch/faiss) | +| | Support Vector Machine Classifier (SVC) | | | **Time Series** | Linear Kalman Filter | | | | Holt-Winters Exponential Smoothing | | --- diff --git a/cpp/src/randomforest/randomforest.cu b/cpp/src/randomforest/randomforest.cu index 59c8ed9c5c..fa82017d4a 100644 --- a/cpp/src/randomforest/randomforest.cu +++ b/cpp/src/randomforest/randomforest.cu @@ -150,10 +150,11 @@ void postprocess_labels(int n_rows, std::vector& labels, * @param[in] cfg_n_streams: No of parallel CUDA for training forest */ void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, - float cfg_rows_sample, int cfg_n_streams) { + float cfg_rows_sample, int cfg_seed, int cfg_n_streams) { params.n_trees = cfg_n_trees; params.bootstrap = cfg_bootstrap; params.rows_sample = cfg_rows_sample; + params.seed = cfg_seed; params.n_streams = min(cfg_n_streams, omp_get_max_threads()); if (params.n_streams == cfg_n_streams) { std::cout << "Warning! Max setting Max streams to max openmp threads " @@ -173,11 +174,12 @@ void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, * @param[in] cfg_tree_params: tree parameters */ void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, - float cfg_rows_sample, int cfg_n_streams, + float cfg_rows_sample, int cfg_seed, int cfg_n_streams, DecisionTree::DecisionTreeParams cfg_tree_params) { params.n_trees = cfg_n_trees; params.bootstrap = cfg_bootstrap; params.rows_sample = cfg_rows_sample; + params.seed = cfg_seed; params.n_streams = min(cfg_n_streams, omp_get_max_threads()); if (cfg_n_trees < params.n_streams) params.n_streams = cfg_n_trees; set_tree_params(params.tree_params); // use input tree params @@ -462,15 +464,16 @@ RF_metrics score(const cumlHandle& user_handle, RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features, int n_bins, int split_algo, int min_rows_per_node, bool bootstrap_features, bool bootstrap, int n_trees, - float rows_sample, CRITERION split_criterion, - bool quantile_per_tree, int cfg_n_streams) { + float rows_sample, int seed, + CRITERION split_criterion, bool quantile_per_tree, + int cfg_n_streams) { DecisionTree::DecisionTreeParams tree_params; DecisionTree::set_tree_params( tree_params, max_depth, max_leaves, max_features, n_bins, split_algo, min_rows_per_node, bootstrap_features, split_criterion, quantile_per_tree); RF_params rf_params; - set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, cfg_n_streams, - tree_params); + set_all_rf_params(rf_params, n_trees, bootstrap, rows_sample, seed, + cfg_n_streams, tree_params); return rf_params; } diff --git a/cpp/src/randomforest/randomforest.hpp b/cpp/src/randomforest/randomforest.hpp index 75f78e6340..277aee5172 100644 --- a/cpp/src/randomforest/randomforest.hpp +++ b/cpp/src/randomforest/randomforest.hpp @@ -65,6 +65,10 @@ struct RF_params { /** * Decision tree training hyper parameter struct. */ + /** + * random seed + */ + int seed; /** * Number of concurrent GPU streams for parallel tree building. * Each stream is independently managed by CPU thread. 
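Editor's note: the `seed` field added to `RF_params` above defaults to `-1` through `set_rf_params` (next hunk), which preserves the legacy per-tree `srand` behavior; any value greater than `-1` makes row sampling reproducible. A minimal configuration sketch, assuming the `ML` namespace used for `randomforest.hpp` (as the Cython externs later in this diff confirm); `make_reproducible_params` is a hypothetical name:

```cpp
#include "randomforest/randomforest.hpp"

// Minimal sketch: build a reproducible forest configuration.
// set_rf_params() now takes cfg_seed immediately before cfg_n_streams;
// its default of -1 keeps the legacy non-deterministic behavior.
ML::RF_params make_reproducible_params() {
  ML::RF_params params;
  ML::set_rf_params(params, /*cfg_n_trees=*/100, /*cfg_bootstrap=*/true,
                    /*cfg_rows_sample=*/1.0f, /*cfg_seed=*/42,
                    /*cfg_n_streams=*/8);
  return params;
}
```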
@@ -76,9 +80,9 @@ struct RF_params { void set_rf_params(RF_params& params, int cfg_n_trees = 1, bool cfg_bootstrap = true, float cfg_rows_sample = 1.0f, - int cfg_n_streams = 8); + int cfg_seed = -1, int cfg_n_streams = 8); void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, - float cfg_rows_sample, int cfg_n_streams, + float cfg_rows_sample, int cfg_seed, int cfg_n_streams, DecisionTree::DecisionTreeParams cfg_tree_params); void validity_check(const RF_params rf_params); void print(const RF_params rf_params); @@ -154,8 +158,9 @@ RF_metrics score(const cumlHandle& user_handle, RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features, int n_bins, int split_algo, int min_rows_per_node, bool bootstrap_features, bool bootstrap, int n_trees, - float rows_sample, CRITERION split_criterion, - bool quantile_per_tree, int cfg_n_streams); + float rows_sample, int seed, + CRITERION split_criterion, bool quantile_per_tree, + int cfg_n_streams); // ----------------------------- Regression ----------------------------------- // diff --git a/cpp/src/randomforest/randomforest_impl.cuh b/cpp/src/randomforest/randomforest_impl.cuh index a1de148a10..471e6da754 100644 --- a/cpp/src/randomforest/randomforest_impl.cuh +++ b/cpp/src/randomforest/randomforest_impl.cuh @@ -72,7 +72,10 @@ void rf::prepare_fit_per_tree( int tree_id, int n_rows, int n_sampled_rows, unsigned int* selected_rows, const int num_sms, const cudaStream_t stream, const std::shared_ptr device_allocator) { - srand(tree_id * 1000); + int rs = tree_id; + if (rf_params.seed > -1) rs = rf_params.seed + tree_id; + + srand(rs * 1000); if (rf_params.bootstrap) { random_uniformInt(tree_id, selected_rows, n_sampled_rows, n_rows, num_sms, stream); @@ -221,10 +224,9 @@ void rfClassifier::fit(const cumlHandle& user_handle, const T* input, unsigned int* rowids; rowids = selected_rows[stream_id]->data(); - this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids, - tempmem[stream_id]->num_sms, - tempmem[stream_id]->stream, - handle.getDeviceAllocator()); + this->prepare_fit_per_tree( + i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms, + tempmem[stream_id]->stream, handle.getDeviceAllocator()); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. @@ -236,8 +238,7 @@ void rfClassifier::fit(const cumlHandle& user_handle, const T* input, */ DecisionTree::TreeMetaDataNode* tree_ptr = &(forest->trees[i]); tree_ptr->treeid = i; - trees[i].fit(handle.getDeviceAllocator(), - handle.getHostAllocator(), + trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(), tempmem[stream_id]->stream, input, n_cols, n_rows, labels, rowids, n_sampled_rows, n_unique_labels, tree_ptr, this->rf_params.tree_params, tempmem[stream_id]); @@ -485,10 +486,9 @@ void rfRegressor::fit(const cumlHandle& user_handle, const T* input, for (int i = 0; i < this->rf_params.n_trees; i++) { int stream_id = omp_get_thread_num(); unsigned int* rowids = selected_rows[stream_id]->data(); - this->prepare_fit_per_tree(i, n_rows, n_sampled_rows, rowids, - tempmem[stream_id]->num_sms, - tempmem[stream_id]->stream, - handle.getDeviceAllocator()); + this->prepare_fit_per_tree( + i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms, + tempmem[stream_id]->stream, handle.getDeviceAllocator()); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. 
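Editor's note: the seeding change in `prepare_fit_per_tree` above reduces to the rule below; `per_tree_srand_seed` is a hypothetical helper written only to restate the logic in isolation:

```cpp
// Restatement of the new rule in prepare_fit_per_tree(): a user seed > -1
// gives every tree a fixed, reproducible offset, while the -1 default
// reduces to the old srand(tree_id * 1000) behavior.
int per_tree_srand_seed(int user_seed, int tree_id) {
  int rs = tree_id;
  if (user_seed > -1) rs = user_seed + tree_id;
  return rs * 1000;  // value passed to srand() before bootstrap sampling
}
```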
@@ -499,8 +499,7 @@ void rfRegressor::fit(const cumlHandle& user_handle, const T* input, */ DecisionTree::TreeMetaDataNode* tree_ptr = &(forest->trees[i]); tree_ptr->treeid = i; - trees[i].fit(handle.getDeviceAllocator(), - handle.getHostAllocator(), + trees[i].fit(handle.getDeviceAllocator(), handle.getHostAllocator(), tempmem[stream_id]->stream, input, n_cols, n_rows, labels, rowids, n_sampled_rows, tree_ptr, this->rf_params.tree_params, tempmem[stream_id]); diff --git a/cpp/src_prims/linalg/binary_op.h b/cpp/src_prims/linalg/binary_op.h index 9f799e076c..a7ad986b7a 100644 --- a/cpp/src_prims/linalg/binary_op.h +++ b/cpp/src_prims/linalg/binary_op.h @@ -22,37 +22,41 @@ namespace MLCommon { namespace LinAlg { -template -__global__ void binaryOpKernel(math_t *out, const math_t *in1, - const math_t *in2, IdxType len, Lambda op) { - typedef TxN_t VecType; - VecType a, b; +template +__global__ void binaryOpKernel(OutType *out, const InType *in1, + const InType *in2, IdxType len, Lambda op) { + typedef TxN_t InVecType; + typedef TxN_t OutVecType; + InVecType a, b; + OutVecType c; IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * blockDim.x); - idx *= VecType::Ratio; + idx *= InVecType::Ratio; if (idx >= len) return; a.load(in1, idx); b.load(in2, idx); #pragma unroll - for (int i = 0; i < VecType::Ratio; ++i) { - a.val.data[i] = op(a.val.data[i], b.val.data[i]); + for (int i = 0; i < InVecType::Ratio; ++i) { + c.val.data[i] = op(a.val.data[i], b.val.data[i]); } - a.store(out, idx); + c.store(out, idx); } -template -void binaryOpImpl(math_t *out, const math_t *in1, const math_t *in2, +template +void binaryOpImpl(OutType *out, const InType *in1, const InType *in2, IdxType len, Lambda op, cudaStream_t stream) { - const IdxType nblks = ceildiv(veclen_ ? len / veclen_ : len, (IdxType)TPB); - binaryOpKernel + const IdxType nblks = ceildiv(VecLen ? len / VecLen : len, (IdxType)TPB); + binaryOpKernel <<>>(out, in1, in2, len, op); CUDA_CHECK(cudaPeekAtLastError()); } /** * @brief perform element-wise binary operation on the input arrays - * @tparam math_t data-type upon which the math operation will be performed + * @tparam InType input data-type * @tparam Lambda the device-lambda performing the actual operation + * @tparam OutType output data-type * @tparam IdxType Integer type used to for addressing * @tparam TPB threads-per-block in the final kernel launched * @param out the output array @@ -61,30 +65,34 @@ void binaryOpImpl(math_t *out, const math_t *in1, const math_t *in2, * @param len number of elements in the input array * @param op the device-lambda * @param stream cuda stream where to launch work + * @note Lambda must be a functor with the following signature: + * `OutType func(const InType& val1, const InType& val2);` */ -template -void binaryOp(math_t *out, const math_t *in1, const math_t *in2, IdxType len, +template +void binaryOp(OutType *out, const InType *in1, const InType *in2, IdxType len, Lambda op, cudaStream_t stream) { - size_t bytes = len * sizeof(math_t); - if (16 / sizeof(math_t) && bytes % 16 == 0) { - binaryOpImpl( + constexpr auto maxSize = + sizeof(InType) > sizeof(OutType) ? 
sizeof(InType) : sizeof(OutType); + size_t bytes = len * maxSize; + if (16 / maxSize && bytes % 16 == 0) { + binaryOpImpl( out, in1, in2, len, op, stream); - } else if (8 / sizeof(math_t) && bytes % 8 == 0) { - binaryOpImpl( + } else if (8 / maxSize && bytes % 8 == 0) { + binaryOpImpl( out, in1, in2, len, op, stream); - } else if (4 / sizeof(math_t) && bytes % 4 == 0) { - binaryOpImpl( + } else if (4 / maxSize && bytes % 4 == 0) { + binaryOpImpl( out, in1, in2, len, op, stream); - } else if (2 / sizeof(math_t) && bytes % 2 == 0) { - binaryOpImpl( + } else if (2 / maxSize && bytes % 2 == 0) { + binaryOpImpl( out, in1, in2, len, op, stream); - } else if (1 / sizeof(math_t)) { - binaryOpImpl( + } else if (1 / maxSize) { + binaryOpImpl( out, in1, in2, len, op, stream); } else { - binaryOpImpl(out, in1, in2, len, op, - stream); + binaryOpImpl(out, in1, in2, len, + op, stream); } } diff --git a/cpp/src_prims/linalg/unary_op.h b/cpp/src_prims/linalg/unary_op.h index 24860e6fae..35b1fdd9cd 100644 --- a/cpp/src_prims/linalg/unary_op.h +++ b/cpp/src_prims/linalg/unary_op.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, NVIDIA CORPORATION. + * Copyright (c) 2018-2019, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,36 +22,40 @@ namespace MLCommon { namespace LinAlg { -template -__global__ void unaryOpKernel(math_t *out, const math_t *in, IdxType len, +template +__global__ void unaryOpKernel(OutType *out, const InType *in, IdxType len, Lambda op) { - typedef TxN_t VecType; - VecType a; + typedef TxN_t InVecType; + typedef TxN_t OutVecType; + InVecType a; + OutVecType b; IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * blockDim.x); - idx *= VecType::Ratio; + idx *= InVecType::Ratio; if (idx >= len) return; a.load(in, idx); #pragma unroll - for (int i = 0; i < VecType::Ratio; ++i) { - a.val.data[i] = op(a.val.data[i]); + for (int i = 0; i < InVecType::Ratio; ++i) { + b.val.data[i] = op(a.val.data[i]); } - a.store(out, idx); + b.store(out, idx); } -template -void unaryOpImpl(math_t *out, const math_t *in, IdxType len, Lambda op, +template +void unaryOpImpl(OutType *out, const InType *in, IdxType len, Lambda op, cudaStream_t stream) { - const IdxType nblks = ceildiv(veclen_ ? len / veclen_ : len, (IdxType)TPB); - unaryOpKernel + const IdxType nblks = ceildiv(VecLen ? 
len / VecLen : len, (IdxType)TPB); + unaryOpKernel <<>>(out, in, len, op); CUDA_CHECK(cudaPeekAtLastError()); } /** * @brief perform element-wise unary operation in the input array - * @tparam math_t data-type upon which the math operation will be performed + * @tparam InType input data-type * @tparam Lambda the device-lambda performing the actual operation + * @tparam OutType output data-type * @tparam IdxType Integer type used to for addressing * @tparam TPB threads-per-block in the final kernel launched * @param out the output array @@ -59,36 +63,41 @@ void unaryOpImpl(math_t *out, const math_t *in, IdxType len, Lambda op, * @param len number of elements in the input array * @param op the device-lambda * @param stream cuda stream where to launch work + * @note Lambda must be a functor with the following signature: + * `OutType func(const InType& val);` */ -template -void unaryOp(math_t *out, const math_t *in, IdxType len, Lambda op, +template +void unaryOp(OutType *out, const InType *in, IdxType len, Lambda op, cudaStream_t stream) { if (len <= 0) return; //silently skip in case of 0 length input - size_t bytes = len * sizeof(math_t); + constexpr auto maxSize = + sizeof(InType) >= sizeof(OutType) ? sizeof(InType) : sizeof(OutType); + size_t bytes = len * maxSize; uint64_t inAddr = uint64_t(in); uint64_t outAddr = uint64_t(out); - if (16 / sizeof(math_t) && bytes % 16 == 0 && inAddr % 16 == 0 && + if (16 / maxSize && bytes % 16 == 0 && inAddr % 16 == 0 && outAddr % 16 == 0) { - unaryOpImpl(out, in, len, - op, stream); - } else if (8 / sizeof(math_t) && bytes % 8 == 0 && inAddr % 8 == 0 && + unaryOpImpl( + out, in, len, op, stream); + } else if (8 / maxSize && bytes % 8 == 0 && inAddr % 8 == 0 && outAddr % 8 == 0) { - unaryOpImpl(out, in, len, - op, stream); - } else if (4 / sizeof(math_t) && bytes % 4 == 0 && inAddr % 4 == 0 && + unaryOpImpl( + out, in, len, op, stream); + } else if (4 / maxSize && bytes % 4 == 0 && inAddr % 4 == 0 && outAddr % 4 == 0) { - unaryOpImpl(out, in, len, - op, stream); - } else if (2 / sizeof(math_t) && bytes % 2 == 0 && inAddr % 2 == 0 && + unaryOpImpl( + out, in, len, op, stream); + } else if (2 / maxSize && bytes % 2 == 0 && inAddr % 2 == 0 && outAddr % 2 == 0) { - unaryOpImpl(out, in, len, - op, stream); - } else if (1 / sizeof(math_t)) { - unaryOpImpl(out, in, len, - op, stream); + unaryOpImpl( + out, in, len, op, stream); + } else if (1 / maxSize) { + unaryOpImpl( + out, in, len, op, stream); } else { - unaryOpImpl(out, in, len, op, stream); + unaryOpImpl(out, in, len, op, + stream); } } diff --git a/cpp/test/prims/binary_op.cu b/cpp/test/prims/binary_op.cu index 917df565d6..c0bac8b438 100644 --- a/cpp/test/prims/binary_op.cu +++ b/cpp/test/prims/binary_op.cu @@ -26,19 +26,21 @@ namespace LinAlg { // Or else, we get the following compilation error // for an extended __device__ lambda cannot have private or protected access // within its class -template -void binaryOpLaunch(T *out, const T *in1, const T *in2, IdxType len, - cudaStream_t stream) { +template +void binaryOpLaunch(OutType *out, const InType *in1, const InType *in2, + IdxType len, cudaStream_t stream) { binaryOp( - out, in1, in2, len, [] __device__(T a, T b) { return a + b; }, stream); + out, in1, in2, len, [] __device__(InType a, InType b) { return a + b; }, + stream); } -template +template class BinaryOpTest - : public ::testing::TestWithParam> { + : public ::testing::TestWithParam> { protected: void SetUp() override { - params = ::testing::TestWithParam>::GetParam(); + params = 
::testing::TestWithParam< + BinaryOpInputs>::GetParam(); Random::Rng r(params.seed); cudaStream_t stream; CUDA_CHECK(cudaStreamCreate(&stream)); @@ -47,8 +49,8 @@ class BinaryOpTest allocate(in2, len); allocate(out_ref, len); allocate(out, len); - r.uniform(in1, len, T(-1.0), T(1.0), stream); - r.uniform(in2, len, T(-1.0), T(1.0), stream); + r.uniform(in1, len, InType(-1.0), InType(1.0), stream); + r.uniform(in2, len, InType(-1.0), InType(1.0), stream); naiveAdd(out_ref, in1, in2, len); binaryOpLaunch(out, in1, in2, len, stream); CUDA_CHECK(cudaStreamDestroy(stream)); @@ -62,8 +64,9 @@ class BinaryOpTest } protected: - BinaryOpInputs params; - T *in1, *in2, *out_ref, *out; + BinaryOpInputs params; + InType *in1, *in2; + OutType *out_ref, *out; }; const std::vector> inputsf_i32 = { @@ -86,6 +89,16 @@ TEST_P(BinaryOpTestF_i64, Result) { INSTANTIATE_TEST_CASE_P(BinaryOpTests, BinaryOpTestF_i64, ::testing::ValuesIn(inputsf_i64)); +const std::vector> inputsf_i32_d = { + {0.000001f, 1024 * 1024, 1234ULL}}; +typedef BinaryOpTest BinaryOpTestF_i32_D; +TEST_P(BinaryOpTestF_i32_D, Result) { + ASSERT_TRUE(devArrMatch(out_ref, out, params.len, + CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(BinaryOpTests, BinaryOpTestF_i32_D, + ::testing::ValuesIn(inputsf_i32_d)); + const std::vector> inputsd_i32 = { {0.00000001, 1024 * 1024, 1234ULL}}; typedef BinaryOpTest BinaryOpTestD_i32; diff --git a/cpp/test/prims/binary_op.h b/cpp/test/prims/binary_op.h index 30d0b3f42b..b1d25d81e2 100644 --- a/cpp/test/prims/binary_op.h +++ b/cpp/test/prims/binary_op.h @@ -22,33 +22,33 @@ namespace MLCommon { namespace LinAlg { -template -__global__ void naiveAddKernel(Type *out, const Type *in1, const Type *in2, - IdxType len) { +template +__global__ void naiveAddKernel(OutType *out, const InType *in1, + const InType *in2, IdxType len) { IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * (IdxType)blockDim.x); if (idx < len) { - out[idx] = in1[idx] + in2[idx]; + out[idx] = static_cast(in1[idx] + in2[idx]); } } -template -void naiveAdd(Type *out, const Type *in1, const Type *in2, IdxType len) { +template +void naiveAdd(OutType *out, const InType *in1, const InType *in2, IdxType len) { static const IdxType TPB = 64; IdxType nblks = ceildiv(len, TPB); - naiveAddKernel<<>>(out, in1, in2, len); + naiveAddKernel<<>>(out, in1, in2, len); CUDA_CHECK(cudaPeekAtLastError()); } -template +template struct BinaryOpInputs { - T tolerance; + InType tolerance; IdxType len; unsigned long long int seed; }; -template +template ::std::ostream &operator<<(::std::ostream &os, - const BinaryOpInputs &dims) { + const BinaryOpInputs &d) { return os; } diff --git a/cpp/test/prims/unary_op.cu b/cpp/test/prims/unary_op.cu index 62d925fbc7..ca9d0d25c9 100644 --- a/cpp/test/prims/unary_op.cu +++ b/cpp/test/prims/unary_op.cu @@ -26,29 +26,31 @@ namespace LinAlg { // Or else, we get the following compilation error // for an extended __device__ lambda cannot have private or protected access // within its class -template -void unaryOpLaunch(T *out, const T *in, T scalar, IdxType len, +template +void unaryOpLaunch(OutType *out, const InType *in, InType scalar, IdxType len, cudaStream_t stream) { - unaryOp( - out, in, len, [scalar] __device__(T in) { return in * scalar; }, stream); + auto op = [scalar] __device__(InType in) { + return static_cast(in * scalar); + }; + unaryOp(out, in, len, op, stream); } -template -class UnaryOpTest : public ::testing::TestWithParam> { +template +class UnaryOpTest + : public ::testing::TestWithParam> { protected: 
void SetUp() override { - params = ::testing::TestWithParam>::GetParam(); + params = ::testing::TestWithParam< + UnaryOpInputs>::GetParam(); Random::Rng r(params.seed); cudaStream_t stream; CUDA_CHECK(cudaStreamCreate(&stream)); - auto len = params.len; auto scalar = params.scalar; - allocate(in, len); allocate(out_ref, len); allocate(out, len); - r.uniform(in, len, T(-1.0), T(1.0), stream); + r.uniform(in, len, InType(-1.0), InType(1.0), stream); naiveScale(out_ref, in, scalar, len, stream); unaryOpLaunch(out, in, scalar, len, stream); CUDA_CHECK(cudaStreamDestroy(stream)); @@ -61,8 +63,9 @@ class UnaryOpTest : public ::testing::TestWithParam> { } protected: - UnaryOpInputs params; - T *in, *out_ref, *out; + UnaryOpInputs params; + InType *in; + OutType *out_ref, *out; }; const std::vector> inputsf_i32 = { @@ -85,6 +88,16 @@ TEST_P(UnaryOpTestF_i64, Result) { INSTANTIATE_TEST_CASE_P(UnaryOpTests, UnaryOpTestF_i64, ::testing::ValuesIn(inputsf_i64)); +const std::vector> inputsf_i32_d = { + {0.000001f, 1024 * 1024, 2.f, 1234ULL}}; +typedef UnaryOpTest UnaryOpTestF_i32_D; +TEST_P(UnaryOpTestF_i32_D, Result) { + ASSERT_TRUE(devArrMatch(out_ref, out, params.len, + CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(UnaryOpTests, UnaryOpTestF_i32_D, + ::testing::ValuesIn(inputsf_i32_d)); + const std::vector> inputsd_i32 = { {0.00000001, 1024 * 1024, 2.0, 1234ULL}}; typedef UnaryOpTest UnaryOpTestD_i32; diff --git a/cpp/test/prims/unary_op.h b/cpp/test/prims/unary_op.h index bc08a56d59..24e5e2ec1c 100644 --- a/cpp/test/prims/unary_op.h +++ b/cpp/test/prims/unary_op.h @@ -22,35 +22,36 @@ namespace MLCommon { namespace LinAlg { -template -__global__ void naiveScaleKernel(Type *out, const Type *in, Type scalar, +template +__global__ void naiveScaleKernel(OutType *out, const InType *in, InType scalar, IdxType len) { IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * (IdxType)blockDim.x); if (idx < len) { - out[idx] = scalar * in[idx]; + out[idx] = static_cast(scalar * in[idx]); } } -template -void naiveScale(Type *out, const Type *in, Type scalar, int len, +template +void naiveScale(OutType *out, const InType *in, InType scalar, int len, cudaStream_t stream) { static const int TPB = 64; int nblks = ceildiv(len, TPB); - naiveScaleKernel<<>>(out, in, scalar, len); + naiveScaleKernel + <<>>(out, in, scalar, len); CUDA_CHECK(cudaPeekAtLastError()); } -template +template struct UnaryOpInputs { - T tolerance; + InType tolerance; IdxType len; - T scalar; + InType scalar; unsigned long long int seed; }; -template +template ::std::ostream &operator<<(::std::ostream &os, - const UnaryOpInputs &dims) { + const UnaryOpInputs &d) { return os; } diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu index 512c53c2af..115cbd8e97 100644 --- a/cpp/test/sg/rf_test.cu +++ b/cpp/test/sg/rf_test.cu @@ -61,7 +61,7 @@ class RfClassifierTest : public ::testing::TestWithParam> { params.split_criterion, false); RF_params rf_params; set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.rows_sample, params.n_streams, tree_params); + params.rows_sample, -1, params.n_streams, tree_params); //print(rf_params); //-------------------------------------------------------- @@ -161,7 +161,7 @@ class RfRegressorTest : public ::testing::TestWithParam> { params.split_criterion, false); RF_params rf_params; set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.rows_sample, params.n_streams, tree_params); + params.rows_sample, -1, params.n_streams, tree_params); //print(rf_params); 
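Editor's note: both test fixtures above pass `-1` for the new seed argument, keeping the pre-PR sampling behavior. A hedged sketch of a determinism-oriented variant (hypothetical `configure_twins` helper with literal values standing in for the fixture's params):

```cpp
// Hypothetical determinism check: two forests configured with the same
// fixed seed (instead of the -1 sentinel the fixtures pass) should draw
// identical bootstrap samples per tree.
void configure_twins(RF_params &a, RF_params &b,
                     const DecisionTree::DecisionTreeParams &tree_params) {
  set_all_rf_params(a, /*cfg_n_trees=*/10, /*cfg_bootstrap=*/true,
                    /*cfg_rows_sample=*/1.0f, /*cfg_seed=*/1234,
                    /*cfg_n_streams=*/4, tree_params);
  set_all_rf_params(b, /*cfg_n_trees=*/10, /*cfg_bootstrap=*/true,
                    /*cfg_rows_sample=*/1.0f, /*cfg_seed=*/1234,
                    /*cfg_n_streams=*/4, tree_params);
}
```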
//-------------------------------------------------------- diff --git a/cpp/test/sg/rf_treelite_test.cu b/cpp/test/sg/rf_treelite_test.cu index d772b0c83e..279db9588e 100644 --- a/cpp/test/sg/rf_treelite_test.cu +++ b/cpp/test/sg/rf_treelite_test.cu @@ -181,7 +181,7 @@ class RfTreeliteTestCommon : public ::testing::TestWithParam> { params.min_rows_per_node, params.bootstrap_features, params.split_criterion, false); set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.rows_sample, params.n_streams, tree_params); + params.rows_sample, -1, params.n_streams, tree_params); // print(rf_params); handle.reset(new cumlHandle(rf_params.n_streams)); diff --git a/docs/source/api.rst b/docs/source/api.rst index fce956891b..ce0fb8d750 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -4,8 +4,8 @@ cuML API Reference -Preprocessing -============== +Preprocessing, Metrics, and Utilities +===================================== Model Selection and Data Splitting ---------------------------------- @@ -24,6 +24,44 @@ Dataset Generation .. automethod:: cuml.datasets.make_blobs +Metrics +--------- + + .. automodule:: cuml.metrics.regression + :members: + + .. automodule:: cuml.metrics.accuracy + :members: + + .. automodule:: cuml.metrics.trustworthiness + :members: + + .. automodule:: cuml.metrics.cluster + :members: + +Benchmarking +------------- + + .. automodule:: cuml.benchmark.algorithms + :members: + + .. automodule:: cuml.benchmark.runners + :members: + + .. automodule:: cuml.benchmark.datagen + :members: + + + +Utilities for I/O and Numba +--------------------------- + + .. automodule:: cuml.utils.input_utils + :members: + + .. automodule:: cuml.utils.numba_utils + :members: + Regression and Classification ============================= @@ -84,6 +122,12 @@ Quasi-Newton .. autoclass:: cuml.QN :members: +Support Vector Machines +------------------------ + +.. autoclass:: cuml.svm.SVC + :members: + Clustering ========== diff --git a/python/cuml/benchmark/datagen.py b/python/cuml/benchmark/datagen.py index 8a881687ad..203d6a6759 100644 --- a/python/cuml/benchmark/datagen.py +++ b/python/cuml/benchmark/datagen.py @@ -22,7 +22,7 @@ * n_samples (set to 0 for 'default') * n_features (set to 0 for 'default') * random_state - * .. and optional generator-specific parameters + * (and optional generator-specific parameters) The function should return a 2-tuple (X, y), where X is a Pandas dataframe and y is a Pandas series. 
If the generator does not produce diff --git a/python/cuml/cluster/kmeans.pyx b/python/cuml/cluster/kmeans.pyx index 802acdc907..aa1cd53fab 100644 --- a/python/cuml/cluster/kmeans.pyx +++ b/python/cuml/cluster/kmeans.pyx @@ -24,7 +24,7 @@ import cudf import numpy as np import warnings -from librmm_cffi import librmm as rmm +import rmm from libcpp cimport bool from libc.stdint cimport uintptr_t diff --git a/python/cuml/cluster/kmeans_mg.pyx b/python/cuml/cluster/kmeans_mg.pyx index d04d448f7a..57785a4a5d 100644 --- a/python/cuml/cluster/kmeans_mg.pyx +++ b/python/cuml/cluster/kmeans_mg.pyx @@ -24,7 +24,7 @@ import cudf import numpy as np import warnings -from librmm_cffi import librmm as rmm +import rmm from libcpp cimport bool from libc.stdint cimport uintptr_t diff --git a/python/cuml/dask/ensemble/randomforestclassifier.py b/python/cuml/dask/ensemble/randomforestclassifier.py index 9e063e18a4..8187af137a 100755 --- a/python/cuml/dask/ensemble/randomforestclassifier.py +++ b/python/cuml/dask/ensemble/randomforestclassifier.py @@ -194,6 +194,12 @@ def __init__( self.n_estimators_per_worker[i] + 1 ) + seeds = list() + seeds.append(0) + for i in range(1, len(self.n_estimators_per_worker)): + sd = self.n_estimators_per_worker[i-1] + seeds[i-1] + seeds.append(sd) + key = str(uuid1()) self.rfs = { worker: c.submit( @@ -213,6 +219,7 @@ def __init__( rows_sample, max_leaves, quantile_per_tree, + seeds[n], dtype, key="%s-%s" % (key, n), workers=[worker], @@ -243,6 +250,7 @@ def _func_build_rf( rows_sample, max_leaves, quantile_per_tree, + seed, dtype, ): return cuRFC( @@ -262,6 +270,7 @@ def _func_build_rf( max_leaves=max_leaves, n_streams=n_streams, quantile_per_tree=quantile_per_tree, + seed=seed, gdf_datatype=dtype, ) diff --git a/python/cuml/dask/ensemble/randomforestregressor.py b/python/cuml/dask/ensemble/randomforestregressor.py index b86e1a9269..e12bb75056 100755 --- a/python/cuml/dask/ensemble/randomforestregressor.py +++ b/python/cuml/dask/ensemble/randomforestregressor.py @@ -197,6 +197,12 @@ def __init__( self.n_estimators_per_worker[i] + 1 ) + seeds = list() + seeds.append(0) + for i in range(1, len(self.n_estimators_per_worker)): + sd = self.n_estimators_per_worker[i-1] + seeds[i-1] + seeds.append(sd) + key = str(uuid1()) self.rfs = { worker: c.submit( @@ -216,6 +222,7 @@ def __init__( max_leaves, accuracy_metric, quantile_per_tree, + seeds[n], key="%s-%s" % (key, n), workers=[worker], ) @@ -245,6 +252,7 @@ def _func_build_rf( max_leaves, accuracy_metric, quantile_per_tree, + seed, ): return cuRFR( @@ -264,6 +272,7 @@ def _func_build_rf( n_streams=n_streams, accuracy_metric=accuracy_metric, quantile_per_tree=quantile_per_tree, + seed=seed, ) @staticmethod diff --git a/python/cuml/dask/linear_model/linear_regression.py b/python/cuml/dask/linear_model/linear_regression.py index 1034af5489..a860c00b6f 100644 --- a/python/cuml/dask/linear_model/linear_regression.py +++ b/python/cuml/dask/linear_model/linear_regression.py @@ -26,7 +26,7 @@ from dask import delayed from dask.distributed import wait, default_client from math import ceil -from librmm_cffi import librmm as rmm +import rmm from toolz import first from tornado import gen diff --git a/python/cuml/dask/neighbors/nearest_neighbors.py b/python/cuml/dask/neighbors/nearest_neighbors.py index b37ee0568c..7128eb2dd1 100644 --- a/python/cuml/dask/neighbors/nearest_neighbors.py +++ b/python/cuml/dask/neighbors/nearest_neighbors.py @@ -24,7 +24,7 @@ import random from cuml.utils import numba_utils -from librmm_cffi import librmm as rmm 
+import rmm from dask import delayed from collections import defaultdict diff --git a/python/cuml/decomposition/pca.pyx b/python/cuml/decomposition/pca.pyx index 78aab743ad..cd28e929ed 100644 --- a/python/cuml/decomposition/pca.pyx +++ b/python/cuml/decomposition/pca.pyx @@ -23,7 +23,7 @@ import ctypes import cudf import numpy as np -from librmm_cffi import librmm as rmm +import rmm from libcpp cimport bool from libc.stdint cimport uintptr_t diff --git a/python/cuml/decomposition/tsvd.pyx b/python/cuml/decomposition/tsvd.pyx index eb10bf9565..66e4866048 100644 --- a/python/cuml/decomposition/tsvd.pyx +++ b/python/cuml/decomposition/tsvd.pyx @@ -23,7 +23,7 @@ import ctypes import cudf import numpy as np -from librmm_cffi import librmm as rmm +import rmm from libcpp cimport bool from libc.stdint cimport uintptr_t diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 04f848a92d..4108a37b8d 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -81,6 +81,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML": int n_trees bool bootstrap float rows_sample + int seed pass cdef cppclass RandomForestMetaData[T, L]: @@ -181,6 +182,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML": bool, int, float, + int, CRITERION, bool, int) except + @@ -302,7 +304,8 @@ class RandomForestClassifier(Base): min_samples_leaf=None, min_weight_fraction_leaf=None, max_leaf_nodes=None, min_impurity_decrease=None, min_impurity_split=None, oob_score=None, n_jobs=None, - random_state=None, warm_start=None, class_weight=None): + random_state=None, warm_start=None, class_weight=None, + seed=-1): sklearn_params = {"criterion": criterion, "min_samples_leaf": min_samples_leaf, @@ -353,6 +356,7 @@ class RandomForestClassifier(Base): self.quantile_per_tree = quantile_per_tree self.n_cols = None self.n_streams = n_streams + self.seed = seed cdef RandomForestMetaData[float, int] *rf_forest = \ new RandomForestMetaData[float, int]() @@ -497,6 +501,7 @@ class RandomForestClassifier(Base): self.bootstrap, self.n_estimators, self.rows_sample, + self.seed, self.split_criterion, self.quantile_per_tree, self.n_streams) @@ -608,6 +613,7 @@ class RandomForestClassifier(Base): num_classes=2): """ Predicts the labels for X. + Parameters ---------- X : array-like (device or host) shape = (n_samples, n_features) @@ -639,9 +645,10 @@ class RandomForestClassifier(Base): It is applied if output_class == True, else it is ignored num_classes : integer number of different classes present in the dataset + Returns ---------- - y: NumPy + y : NumPy Dense vector (int) of shape (n_samples, 1) """ if self.dtype == np.float64: @@ -662,15 +669,17 @@ class RandomForestClassifier(Base): def _predict_get_all(self, X): """ Predicts the labels for X. + Parameters ---------- X : array-like (device or host) shape = (n_samples, n_features) Dense matrix (floats or doubles) of shape (n_samples, n_features). Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device ndarray, cuda array interface compliant array like CuPy + Returns ---------- - y: NumPy + y : NumPy Dense vector (int) of shape (n_samples, 1) """ cdef uintptr_t X_ptr @@ -724,17 +733,20 @@ class RandomForestClassifier(Base): def score(self, X, y): """ Calculates the accuracy metric score of the model for X. 
+ Parameters ---------- X : array-like (device or host) shape = (n_samples, n_features) Dense matrix (floats or doubles) of shape (n_samples, n_features). Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device ndarray, cuda array interface compliant array like CuPy - y: NumPy + y : NumPy Dense vector (int) of shape (n_samples, 1) + Returns - ---------- - accuracy of the model + ------- + float + Accuracy of the model [0.0 - 1.0] """ cdef uintptr_t X_ptr, y_ptr X_m, X_ptr, n_rows, n_cols, _ = \ @@ -795,6 +807,7 @@ class RandomForestClassifier(Base): """ Returns the value of all parameters required to configure this estimator as a dictionary. + Parameters ----------- deep : boolean (default = True) @@ -811,6 +824,7 @@ class RandomForestClassifier(Base): Sets the value of parameters required to configure this estimator, it functions similar to the sklearn set_params. + Parameters ----------- params : dict of new params diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index 0b9009b510..3c675c74ef 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -80,6 +80,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML": int n_trees bool bootstrap float rows_sample + int seed pass cdef cppclass RandomForestMetaData[T, L]: @@ -162,6 +163,7 @@ cdef extern from "randomforest/randomforest.hpp" namespace "ML": bool, int, float, + int, CRITERION, bool, int) except + @@ -286,7 +288,7 @@ class RandomForestRegressor(Base): max_leaf_nodes=None, min_impurity_decrease=None, min_impurity_split=None, oob_score=None, random_state=None, warm_start=None, class_weight=None, - quantile_per_tree=False, criterion=None): + quantile_per_tree=False, criterion=None, seed=-1): sklearn_params = {"criterion": criterion, "min_samples_leaf": min_samples_leaf, @@ -337,6 +339,7 @@ class RandomForestRegressor(Base): self.accuracy_metric = accuracy_metric self.quantile_per_tree = quantile_per_tree self.n_streams = n_streams + self.seed = seed cdef RandomForestMetaData[float, float] *rf_forest = \ new RandomForestMetaData[float, float]() @@ -461,6 +464,7 @@ class RandomForestRegressor(Base): self.bootstrap, self.n_estimators, self.rows_sample, + self.seed, self.split_criterion, self.quantile_per_tree, self.n_streams) diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 58ecb9645a..d55983abda 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -26,7 +26,7 @@ import math import numpy as np import warnings -from librmm_cffi import librmm as rmm +import rmm from libcpp cimport bool from libc.stdint cimport uintptr_t diff --git a/python/cuml/filter/kalman_filter.pyx b/python/cuml/filter/kalman_filter.pyx index 7308fa14ab..22bf510a33 100644 --- a/python/cuml/filter/kalman_filter.pyx +++ b/python/cuml/filter/kalman_filter.pyx @@ -24,7 +24,7 @@ import numpy as np from numba import cuda from cuml.utils import numba_utils -from librmm_cffi import librmm as rmm +import rmm from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 9ee124b97f..cb27bbf9af 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -32,7 +32,7 @@ from cuml.common.base import Base from cuml.common.handle cimport cumlHandle from cuml.utils import input_to_dev_array as to_cuda -from librmm_cffi import librmm as rmm +import rmm from libcpp cimport bool from 
libc.stdint cimport uintptr_t diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 1f120c0685..25a42669a5 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -31,7 +31,7 @@ from cuml.common.handle cimport cumlHandle from cuml.utils import get_cudf_column_ptr, get_dev_array_ptr, \ input_to_dev_array, zeros, row_matrix -from librmm_cffi import librmm as rmm +import rmm from libcpp cimport bool from libc.stdint cimport uintptr_t diff --git a/python/cuml/metrics/accuracy.pyx b/python/cuml/metrics/accuracy.pyx index dfebe1910f..5ce62bb591 100644 --- a/python/cuml/metrics/accuracy.pyx +++ b/python/cuml/metrics/accuracy.pyx @@ -42,13 +42,16 @@ def accuracy_score(ground_truth, predictions, handle=None): Parameters ---------- - handle : cuml.Handle - prediction : The lables predicted by the model - for the test dataset - ground_truth : The ground truth labels of the test dataset + handle : cuml.Handle + prediction : NumPy ndarray or Numba device + The lables predicted by the model for the test dataset + ground_truth : NumPy ndarray, Numba device + The ground truth labels of the test dataset + Returns ------- - The accuracy of the model used for prediction + float + The accuracy of the model used for prediction """ handle = cuml.common.handle.Handle() \ if handle is None else handle diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 7e21a1aa87..49812e76f4 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -35,14 +35,14 @@ from cython.operator cimport dereference as deref from libcpp cimport bool from libcpp.memory cimport shared_ptr -from librmm_cffi import librmm as rmm +import rmm from libc.stdlib cimport malloc, free from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free from numba import cuda -from librmm_cffi import librmm as rmm +import rmm cimport cuml.common.handle cimport cuml.common.cuda diff --git a/python/cuml/preprocessing/LabelEncoder.py b/python/cuml/preprocessing/LabelEncoder.py index a62765b4c5..f5f28a0f27 100644 --- a/python/cuml/preprocessing/LabelEncoder.py +++ b/python/cuml/preprocessing/LabelEncoder.py @@ -16,7 +16,7 @@ import cudf import nvcategory -from librmm_cffi import librmm +import rmm import numpy as np @@ -194,7 +194,7 @@ def fit_transform(self, y: cudf.Series) -> cudf.Series: self._cats = nvcategory.from_strings(y.data) self._fitted = True - arr: librmm.device_array = librmm.device_array( + arr: rmm.device_array = rmm.device_array( y.data.size(), dtype=np.int32 ) self._cats.values(devptr=arr.device_ctypes_pointer.value) diff --git a/python/cuml/random_projection/random_projection.pyx b/python/cuml/random_projection/random_projection.pyx index 7897eecc61..efa6ba8193 100644 --- a/python/cuml/random_projection/random_projection.pyx +++ b/python/cuml/random_projection/random_projection.pyx @@ -22,7 +22,7 @@ import cudf import numpy as np -from librmm_cffi import librmm as rmm +import rmm from libc.stdint cimport uintptr_t from libcpp cimport bool diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx index f783590dfe..393ab2b379 100644 --- a/python/cuml/solvers/qn.pyx +++ b/python/cuml/solvers/qn.pyx @@ -23,7 +23,7 @@ import cudf import numpy as np import warnings -from librmm_cffi import librmm as rmm +import rmm from libcpp cimport bool from libc.stdint cimport uintptr_t diff --git a/python/cuml/svm/svm.pyx b/python/cuml/svm/svm.pyx index 
141e8a5002..8ebf541f49 100644 --- a/python/cuml/svm/svm.pyx +++ b/python/cuml/svm/svm.pyx @@ -442,9 +442,10 @@ class SVC(Base): Dense matrix (floats or doubles) of shape (n_samples, n_features). Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device ndarray, cuda array interface compliant array like CuPy + Returns - ---------- - y: cuDF Series + ------- + y : cuDF Series Dense vector (floats or doubles) of shape (n_samples, 1) """ diff --git a/python/cuml/utils/input_utils.py b/python/cuml/utils/input_utils.py index 6170ab4223..62987d7bb8 100644 --- a/python/cuml/utils/input_utils.py +++ b/python/cuml/utils/input_utils.py @@ -26,7 +26,7 @@ from collections.abc import Collection from numba import cuda -from librmm_cffi import librmm as rmm +import rmm inp_array = namedtuple('inp_array', 'array pointer n_rows n_cols dtype') @@ -71,14 +71,20 @@ def input_to_dev_array(X, order='F', deepcopy=False, check_cols=False, check_rows=False, fail_on_order=False): """ - Convert input X to device array suitable for C++ methods + Convert input X to device array suitable for C++ methods. + Acceptable input formats: + * cuDF Dataframe - returns a deep copy always + * cuDF Series - returns by reference or a deep copy depending on `deepcopy` + * Numpy array - returns a copy in device always + * cuda array interface compliant array (like Cupy) - returns a - reference unless deepcopy=True + reference unless `deepcopy`=True + * numba device array - returns a reference unless deepcopy=True Parameters @@ -309,13 +315,18 @@ def input_to_host_array(X, order='F', deepcopy=False, """ Convert input X to host array (NumPy) suitable for C++ methods that accept host arrays. + Acceptable input formats: + * Numpy array - returns a pointer to the original input + * cuDF Dataframe - returns a deep copy always - * cuDF Series - returns by reference or a deep copy depending on - `deepcopy` - * cuda array interface compliant array (like Cupy) - returns a + + * cuDF Series - returns by reference or a deep copy depending on `deepcopy` + + * cuda array interface compliant array (like Cupy) - returns a \ reference unless deepcopy=True + * numba device array - returns a reference unless deepcopy=True Parameters diff --git a/python/cuml/utils/numba_utils.py b/python/cuml/utils/numba_utils.py index 288063b165..dc6b2d8c00 100644 --- a/python/cuml/utils/numba_utils.py +++ b/python/cuml/utils/numba_utils.py @@ -18,7 +18,7 @@ from numba import cuda from numba.cuda.cudadrv.driver import driver -from librmm_cffi import librmm as rmm +import rmm import numpy as np
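
Editor's note: as a usage reference for the `InType`/`OutType` split in `cpp/src_prims/linalg/unary_op.h` and `binary_op.h`, here is a minimal sketch mirroring the new `BinaryOpTestF_i32_D` / `UnaryOpTestF_i32_D` cases. It assumes the `src_prims` include paths and an `--expt-extended-lambda` build (as the prims tests use); the function names are illustrative:

```cpp
#include "linalg/binary_op.h"
#include "linalg/unary_op.h"

namespace MLCommon {
namespace LinAlg {

// Mirrors BinaryOpTestF_i32_D: float inputs, double output. The lambda's
// return type sets OutType, per the new @note:
//   `OutType func(const InType& val1, const InType& val2);`
void add_widen_to_double(double *out, const float *in1, const float *in2,
                         int len, cudaStream_t stream) {
  binaryOp(
    out, in1, in2, len,
    [] __device__(float a, float b) { return static_cast<double>(a) + b; },
    stream);
}

// unaryOp is analogous, with `OutType func(const InType& val);`.
void scale_widen_to_double(double *out, const float *in, float scalar,
                           int len, cudaStream_t stream) {
  unaryOp(
    out, in, len,
    [scalar] __device__(float v) { return static_cast<double>(v * scalar); },
    stream);
}

}  // namespace LinAlg
}  // namespace MLCommon
```

Note the design choice visible in the diff: `maxSize` is the larger of `sizeof(InType)` and `sizeof(OutType)`, so the vector length is chosen against the wider element type and both `TxN_t` views load/store the same number of elements per thread, which is what lets the float-to-double paths above still dispatch through the widest (16-byte) implementation when `len` allows.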