diff --git a/docs/Installation-Guide.rst b/docs/Installation-Guide.rst index 62383d9f924a..8e965210c1f5 100644 --- a/docs/Installation-Guide.rst +++ b/docs/Installation-Guide.rst @@ -624,9 +624,8 @@ Build CUDA Version The `original GPU build <#build-gpu-version>`__ of LightGBM (``device_type=gpu``) is based on OpenCL. -The CUDA-based build (``device_type=cuda``) is a separate implementation and requires an NVIDIA graphics card with compute capability 6.0 and higher. It should be considered experimental, and we suggest using it only when it is impossible to use OpenCL version (for example, on IBM POWER microprocessors). - -**Note**: only Linux is supported, other operating systems are not supported yet. +The CUDA-based build (``device_type=cuda``) is a separate implementation. +Use this version in Linux environments with an NVIDIA GPU with compute capability 6.0 or higher. Linux ^^^^^ @@ -654,6 +653,17 @@ To build LightGBM CUDA version, run the following commands: **Note**: In some rare cases you may need to install OpenMP runtime library separately (use your package manager and search for ``lib[g|i]omp`` for doing this). +macOS +^^^^^ + +The CUDA version is not supported on macOS. + +Windows +^^^^^^^ + +The CUDA version is not supported on Windows. +Use the GPU version (``device_type=gpu``) for GPU acceleration on Windows. + Build HDFS Version ~~~~~~~~~~~~~~~~~~ diff --git a/docs/_static/js/script.js b/docs/_static/js/script.js index d6f5b4125057..89d14d14aaf0 100644 --- a/docs/_static/js/script.js +++ b/docs/_static/js/script.js @@ -28,7 +28,7 @@ $(function() { '#build-threadless-version-not-recommended', '#build-mpi-version', '#build-gpu-version', - '#build-cuda-version-experimental', + '#build-cuda-version', '#build-hdfs-version', '#build-java-wrapper', '#build-c-unit-tests' diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 6500cb77272d..4bb4c394b03c 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -38,6 +38,10 @@ const int kDefaultNumLeaves = 31; struct Config { public: + Config() {} + explicit Config(std::unordered_map parameters_map) { + Set(parameters_map); + } std::string ToString() const; /*! * \brief Get string value by specific name of key diff --git a/python-package/README.rst b/python-package/README.rst index bf9874e1227c..c3b73ffdf5d1 100644 --- a/python-package/README.rst +++ b/python-package/README.rst @@ -153,7 +153,7 @@ Build CUDA Version All requirements from `Build from Sources section <#build-from-sources>`__ apply for this installation option as well, and `CMake`_ (version 3.16 or higher) is strongly required. -**CUDA** library (version 10.0 or higher) is needed: details for installation can be found in `Installation Guide `__. +**CUDA** library (version 10.0 or higher) is needed: details for installation can be found in `Installation Guide `__. To use the CUDA version within Python, pass ``{"device": "cuda"}`` respectively in parameters. diff --git a/src/c_api.cpp b/src/c_api.cpp index 67b18003588a..67a6d05b75a7 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -59,12 +59,12 @@ yamc::shared_lock lock(&mtx); const int PREDICTOR_TYPES = 4; // Single row predictor to abstract away caching logic -class SingleRowPredictor { +class SingleRowPredictorInner { public: PredictFunction predict_function; int64_t num_pred_in_one_row; - SingleRowPredictor(int predict_type, Boosting* boosting, const Config& config, int start_iter, int num_iter) { + SingleRowPredictorInner(int predict_type, Boosting* boosting, const Config& config, int start_iter, int num_iter) { bool is_predict_leaf = false; bool is_raw_score = false; bool predict_contrib = false; @@ -86,7 +86,7 @@ class SingleRowPredictor { num_total_model_ = boosting->NumberOfTotalModel(); } - ~SingleRowPredictor() {} + ~SingleRowPredictorInner() {} bool IsPredictorEqual(const Config& config, int iter, Boosting* boosting) { return early_stop_ == config.pred_early_stop && @@ -105,6 +105,60 @@ class SingleRowPredictor { int num_total_model_; }; +/*! + * \brief Object to store resources meant for single-row Fast Predict methods. + * + * For legacy reasons this is called `FastConfig` in the public C API. + * + * Meant to be used by the *Fast* predict methods only. + * It stores the configuration and prediction resources for reuse across predictions. + */ +struct SingleRowPredictor { + public: + SingleRowPredictor(yamc::alternate::shared_mutex *booster_mutex, + const char *parameters, + const int data_type, + const int32_t num_cols, + int predict_type, + Boosting *boosting, + int start_iter, + int num_iter) : config(Config::Str2Map(parameters)), data_type(data_type), num_cols(num_cols), single_row_predictor_inner(predict_type, boosting, config, start_iter, num_iter), booster_mutex(booster_mutex) { + if (!config.predict_disable_shape_check && num_cols != boosting->MaxFeatureIdx() + 1) { + Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\ + "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", num_cols, boosting->MaxFeatureIdx() + 1); + } + } + + void Predict(std::function>(int row_idx)> get_row_fun, + double* out_result, int64_t* out_len) const { + UNIQUE_LOCK(single_row_predictor_mutex) + yamc::shared_lock booster_shared_lock(booster_mutex); + + auto one_row = get_row_fun(0); + single_row_predictor_inner.predict_function(one_row, out_result); + + *out_len = single_row_predictor_inner.num_pred_in_one_row; + } + + public: + Config config; + const int data_type; + const int32_t num_cols; + + private: + SingleRowPredictorInner single_row_predictor_inner; + + // Prevent the booster from being modified while we have a predictor relying on it during prediction + yamc::alternate::shared_mutex *booster_mutex; + + // If several threads try to predict at the same time using the same SingleRowPredictor + // we want them to still provide correct values, so the mutex is necessary due to the shared + // resources in the predictor. + // However the recommended approach is to instantiate one SingleRowPredictor per thread, + // to avoid contention here. + mutable yamc::alternate::shared_mutex single_row_predictor_mutex; +}; + class Booster { public: explicit Booster(const char* filename) { @@ -374,15 +428,26 @@ class Booster { boosting_->RollbackOneIter(); } - void SetSingleRowPredictor(int start_iteration, int num_iteration, int predict_type, const Config& config) { + void SetSingleRowPredictorInner(int start_iteration, int num_iteration, int predict_type, const Config& config) { UNIQUE_LOCK(mutex_) if (single_row_predictor_[predict_type].get() == nullptr || !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) { - single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(), + single_row_predictor_[predict_type].reset(new SingleRowPredictorInner(predict_type, boosting_.get(), config, start_iteration, num_iteration)); } } + std::unique_ptr InitSingleRowPredictor(int predict_type, int start_iteration, int num_iteration, int data_type, int32_t num_cols, const char *parameters) { + // Workaround https://github.com/microsoft/LightGBM/issues/6142 by locking here + // This is only a workaround because if predictors are initialized differently it may still behave incorrectly, + // and because multiple racing Predictor initializations through LGBM_BoosterPredictForMat suffers from that same issue of Predictor init writing things in the booster. + // Once #6142 is fixed (predictor doesn't write in the Booster as should have been the case since 1c35c3b9ede9adab8ccc5fd7b4b2b6af188a79f0), this line can be removed. + UNIQUE_LOCK(mutex_) + + return std::unique_ptr(new SingleRowPredictor( + &mutex_, parameters, data_type, num_cols, predict_type, boosting_.get(), start_iteration, num_iteration)); + } + void PredictSingleRow(int predict_type, int ncol, std::function>(int row_idx)> get_row_fun, const Config& config, @@ -815,7 +880,7 @@ class Booster { private: const Dataset* train_data_; std::unique_ptr boosting_; - std::unique_ptr single_row_predictor_[PREDICTOR_TYPES]; + std::unique_ptr single_row_predictor_[PREDICTOR_TYPES]; /*! \brief All configs */ Config config_; @@ -850,6 +915,7 @@ using LightGBM::Log; using LightGBM::Network; using LightGBM::Random; using LightGBM::ReduceScatterFunction; +using LightGBM::SingleRowPredictor; // some help functions used to convert data @@ -2163,35 +2229,15 @@ int LGBM_BoosterCalcNumPredict(BoosterHandle handle, API_END(); } -/*! - * \brief Object to store resources meant for single-row Fast Predict methods. - * - * Meant to be used as a basic struct by the *Fast* predict methods only. - * It stores the configuration resources for reuse during prediction. - * - * Even the row function is stored. We score the instance at the same memory - * address all the time. One just replaces the feature values at that address - * and scores again with the *Fast* methods. - */ -struct FastConfig { - FastConfig(Booster *const booster_ptr, - const char *parameter, - const int predict_type_, - const int data_type_, - const int32_t num_cols) : booster(booster_ptr), predict_type(predict_type_), data_type(data_type_), ncol(num_cols) { - config.Set(Config::Str2Map(parameter)); - } - - Booster* const booster; - Config config; - const int predict_type; - const int data_type; - const int32_t ncol; -}; - +// Naming: In future versions of LightGBM, public API named around `FastConfig` should be made named around +// `SingleRowPredictor`, because it is specific to single row prediction, and doesn't actually hold only config. +// For now this is kept as `FastConfig` for backwards compatibility. +// At the same time, one should consider removing the old non-fast single row public API that stores its Predictor +// in the Booster, because that will enable removing these Predictors from the Booster, and associated initialization +// code. int LGBM_FastConfigFree(FastConfigHandle fastConfig) { API_BEGIN(); - delete reinterpret_cast(fastConfig); + delete reinterpret_cast(fastConfig); API_END(); } @@ -2339,7 +2385,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, OMP_SET_NUM_THREADS(config.num_threads); Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); - ref_booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, config); + ref_booster->SetSingleRowPredictorInner(start_iteration, num_iteration, predict_type, config); ref_booster->PredictSingleRow(predict_type, static_cast(num_col), get_row_fun, config, out_result, out_len); API_END(); } @@ -2359,18 +2405,14 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, Log::Fatal("The number of columns should be smaller than INT32_MAX."); } - auto fastConfig_ptr = std::unique_ptr(new FastConfig( - reinterpret_cast(handle), - parameter, - predict_type, - data_type, - static_cast(num_col))); + Booster* ref_booster = reinterpret_cast(handle); - OMP_SET_NUM_THREADS(fastConfig_ptr->config.num_threads); + std::unique_ptr single_row_predictor = + ref_booster->InitSingleRowPredictor(start_iteration, num_iteration, predict_type, data_type, static_cast(num_col), parameter); - fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config); + OMP_SET_NUM_THREADS(single_row_predictor->config.num_threads); - *out_fastConfig = fastConfig_ptr.release(); + *out_fastConfig = single_row_predictor.release(); API_END(); } @@ -2384,10 +2426,9 @@ int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle, int64_t* out_len, double* out_result) { API_BEGIN(); - FastConfig *fastConfig = reinterpret_cast(fastConfig_handle); - auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, fastConfig->data_type, nindptr, nelem); - fastConfig->booster->PredictSingleRow(fastConfig->predict_type, fastConfig->ncol, - get_row_fun, fastConfig->config, out_result, out_len); + SingleRowPredictor *single_row_predictor = reinterpret_cast(fastConfig_handle); + auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, single_row_predictor->data_type, nindptr, nelem); + single_row_predictor->Predict(get_row_fun, out_result, out_len); API_END(); } @@ -2502,7 +2543,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, OMP_SET_NUM_THREADS(config.num_threads); Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major); - ref_booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, config); + ref_booster->SetSingleRowPredictorInner(start_iteration, num_iteration, predict_type, config); ref_booster->PredictSingleRow(predict_type, ncol, get_row_fun, config, out_result, out_len); API_END(); } @@ -2516,18 +2557,14 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, const char* parameter, FastConfigHandle *out_fastConfig) { API_BEGIN(); - auto fastConfig_ptr = std::unique_ptr(new FastConfig( - reinterpret_cast(handle), - parameter, - predict_type, - data_type, - ncol)); + Booster* ref_booster = reinterpret_cast(handle); - OMP_SET_NUM_THREADS(fastConfig_ptr->config.num_threads); + std::unique_ptr single_row_predictor = + ref_booster->InitSingleRowPredictor(predict_type, start_iteration, num_iteration, data_type, ncol, parameter); - fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config); + OMP_SET_NUM_THREADS(single_row_predictor->config.num_threads); - *out_fastConfig = fastConfig_ptr.release(); + *out_fastConfig = single_row_predictor.release(); API_END(); } @@ -2536,12 +2573,10 @@ int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle, int64_t* out_len, double* out_result) { API_BEGIN(); - FastConfig *fastConfig = reinterpret_cast(fastConfig_handle); + SingleRowPredictor *single_row_predictor = reinterpret_cast(fastConfig_handle); // Single row in row-major format: - auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, fastConfig->ncol, fastConfig->data_type, 1); - fastConfig->booster->PredictSingleRow(fastConfig->predict_type, fastConfig->ncol, - get_row_fun, fastConfig->config, - out_result, out_len); + auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, single_row_predictor->num_cols, single_row_predictor->data_type, 1); + single_row_predictor->Predict(get_row_fun, out_result, out_len); API_END(); } diff --git a/tests/cpp_tests/test_single_row.cpp b/tests/cpp_tests/test_single_row.cpp new file mode 100644 index 000000000000..c14a681bea68 --- /dev/null +++ b/tests/cpp_tests/test_single_row.cpp @@ -0,0 +1,143 @@ +/*! + * Copyright (c) 2022 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ + +#include +#include +#include + +#include +#include + +using LightGBM::TestUtils; + +TEST(SingleRow, JustWorks) { + // Load some test data + int result; + + DatasetHandle train_dataset; + result = TestUtils::LoadDatasetFromExamples("binary_classification/binary.train", "max_bin=15", &train_dataset); + EXPECT_EQ(0, result) << "LoadDatasetFromExamples train result code: " << result; + + BoosterHandle booster_handle; + result = LGBM_BoosterCreate(train_dataset, "app=binary metric=auc num_leaves=31 verbose=0", &booster_handle); + EXPECT_EQ(0, result) << "LGBM_BoosterCreate result code: " << result; + + for (int i = 0; i < 51; i++) { + int is_finished; + result = LGBM_BoosterUpdateOneIter( + booster_handle, + &is_finished); + EXPECT_EQ(0, result) << "LGBM_BoosterUpdateOneIter result code: " << result; + } + + int n_features; + result = LGBM_BoosterGetNumFeature( + booster_handle, + &n_features); + EXPECT_EQ(0, result) << "LGBM_BoosterGetNumFeature result code: " << result; + + // Run a single row prediction and compare with regular Mat prediction: + int64_t output_size; + result = LGBM_BoosterCalcNumPredict( + booster_handle, + 1, + C_API_PREDICT_NORMAL, // predict_type + 0, // start_iteration + -1, // num_iteration + &output_size); + EXPECT_EQ(0, result) << "LGBM_BoosterCalcNumPredict result code: " << result; + + std::ifstream test_file("../examples/binary_classification/binary.test"); + std::vector test; + double x; + int test_set_size = 0; + while (test_file >> x) { + if (test_set_size % (n_features + 1) == 0) { + // Drop the result from the dataset, we only care about checking that prediction results are equal + // in both cases + test_file >> x; + test_set_size++; + } + test.push_back(x); + test_set_size++; + } + EXPECT_EQ(test_set_size % (n_features + 1), 0) << "Test size mismatch with dataset size (%)"; + test_set_size /= (n_features + 1); + EXPECT_EQ(test_set_size, 500) << "Improperly parsed test file (test_set_size)"; + EXPECT_EQ(test.size(), test_set_size * n_features) << "Improperly parsed test file (test len)"; + + std::vector mat_output(output_size * test_set_size, -1); + int64_t written; + result = LGBM_BoosterPredictForMat( + booster_handle, + &test[0], + C_API_DTYPE_FLOAT64, + test_set_size, // nrow + n_features, // ncol + 1, // is_row_major + C_API_PREDICT_NORMAL, // predict_type + 0, // start_iteration + -1, // num_iteration + "", + &written, + &mat_output[0]); + EXPECT_EQ(0, result) << "LGBM_BoosterPredictForMat result code: " << result; + + // Now let's run with the single row fast prediction API: + const int kNThreads = 10; + FastConfigHandle fast_configs[kNThreads]; + for (int i = 0; i < kNThreads; i++) { + result = LGBM_BoosterPredictForMatSingleRowFastInit( + booster_handle, + C_API_PREDICT_NORMAL, // predict_type + 0, // start_iteration + -1, // num_iteration + C_API_DTYPE_FLOAT64, + n_features, + "", + &fast_configs[i]); + EXPECT_EQ(0, result) << "LGBM_BoosterPredictForMatSingleRowFastInit result code: " << result; + } + + std::vector single_row_output(output_size * test_set_size, -1); + std::vector threads(kNThreads); + int batch_size = (test_set_size + kNThreads - 1) / kNThreads; // round up + for (int i = 0; i < kNThreads; i++) { + threads[i] = std::thread( + [ + i, batch_size, test_set_size, output_size, n_features, + test = &test[0], fast_configs = &fast_configs[0], single_row_output = &single_row_output[0] + ](){ + int result; + int64_t written; + for (int j = i * batch_size; j < std::min((i + 1) * batch_size, test_set_size); j++) { + result = LGBM_BoosterPredictForMatSingleRowFast( + fast_configs[i], + &test[j * n_features], + &written, + &single_row_output[j * output_size]); + EXPECT_EQ(0, result) << "LGBM_BoosterPredictForMatSingleRowFast result code: " << result; + EXPECT_EQ(written, output_size) << "LGBM_BoosterPredictForMatSingleRowFast unexpected written output size"; + } + }); + } + for (std::thread &t : threads) { + t.join(); + } + + EXPECT_EQ(single_row_output, mat_output) << "LGBM_BoosterPredictForMatSingleRowFast output mismatch with LGBM_BoosterPredictForMat"; + + // Free all: + for (int i = 0; i < kNThreads; i++) { + result = LGBM_FastConfigFree(fast_configs[i]); + EXPECT_EQ(0, result) << "LGBM_FastConfigFree result code: " << result; + } + + result = LGBM_BoosterFree(booster_handle); + EXPECT_EQ(0, result) << "LGBM_BoosterFree result code: " << result; + + result = LGBM_DatasetFree(train_dataset); + EXPECT_EQ(0, result) << "LGBM_DatasetFree result code: " << result; +}