Skip to content

Commit

Permalink
[breaking] Remove the predictor param, allow fallback to prediction…
Browse files Browse the repository at this point in the history
… using `DMatrix`. (#9129)


- A `DeviceOrd` struct is implemented to indicate the device. It will eventually replace the `gpu_id` parameter.
- The `predictor` parameter is removed.
- Fallback to `DMatrix` when `inplace_predict` is not available.
- The heuristic for choosing a predictor is only used during training.
  • Loading branch information
trivialfis authored Jul 3, 2023
1 parent 3a0f787 commit 39390cc
Show file tree
Hide file tree
Showing 54 changed files with 1,049 additions and 778 deletions.
2 changes: 1 addition & 1 deletion doc/gpu/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ XGBoost makes use of `GPUTreeShap <https://github.com/rapidsai/gputreeshap>`_ as

.. code-block:: python
model.set_param({"predictor": "gpu_predictor"})
model.set_param({"gpu_id": "0", "tree_method": "gpu_hist"})
shap_values = model.predict(dtrain, pred_contribs=True)
shap_interaction_values = model.predict(dtrain, pred_interactions=True)
Expand Down
12 changes: 0 additions & 12 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -199,18 +199,6 @@ Parameters for Tree Booster
- Maximum number of discrete bins to bucket continuous features.
- Increasing this number improves the optimality of splits at the cost of higher computation time.

* ``predictor``, [default= ``auto``]

- The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.

- ``auto``: Configure predictor based on heuristics.
- ``cpu_predictor``: Multicore CPU prediction algorithm.
- ``gpu_predictor``: Prediction using GPU. Used when ``tree_method`` is ``gpu_hist``.
When ``predictor`` is set to default value ``auto``, the ``gpu_hist`` tree method is
able to provide GPU based prediction without copying training data to GPU memory.
If ``gpu_predictor`` is explicitly specified, then all data is copied into GPU, only
recommended for performing prediction tasks.

* ``num_parallel_tree``, [default=1]

- Number of parallel trees constructed during each iteration. This option is used to support boosted random forest.
Expand Down
19 changes: 0 additions & 19 deletions doc/prediction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,6 @@ with the native Python interface :py:meth:`xgboost.Booster.predict` and
behavior. Also the ``save_best`` parameter from :py:obj:`xgboost.callback.EarlyStopping`
might be useful.

*********
Predictor
*********

There are 2 predictors in XGBoost (3 if you have the one-api plugin enabled), namely
``cpu_predictor`` and ``gpu_predictor``. The default option is ``auto`` so that XGBoost
can employ some heuristics for saving GPU memory during training. They might have slight
different outputs due to floating point errors.


***********
Base Margin
Expand Down Expand Up @@ -134,15 +125,6 @@ it. Be aware that the output of in-place prediction depends on input data type,
input is on GPU data output is :py:obj:`cupy.ndarray`, otherwise a :py:obj:`numpy.ndarray`
is returned.

****************
Categorical Data
****************

Other than users performing encoding, XGBoost has experimental support for categorical
data using ``gpu_hist`` and ``gpu_predictor``. No special operation needs to be done on
input test data since the information about categories is encoded into the model during
training.

*************
Thread Safety
*************
Expand All @@ -159,7 +141,6 @@ instance we might accidentally call ``clf.set_params()`` inside a predict functi
def predict_fn(clf: xgb.XGBClassifier, X):
X = preprocess(X)
clf.set_params(predictor="gpu_predictor") # NOT safe!
clf.set_params(n_jobs=1) # NOT safe!
return clf.predict_proba(X, iteration_range=(0, 10))
Expand Down
4 changes: 2 additions & 2 deletions doc/tutorials/dask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ Also for inplace prediction:

.. code-block:: python
booster.set_param({'predictor': 'gpu_predictor'})
# where X is a dask DataFrame or dask Array containing cupy or cuDF backed data.
# where X is a dask DataFrame or dask Array backed by cupy or cuDF.
booster.set_param({"gpu_id": "0"})
prediction = xgb.dask.inplace_predict(client, booster, X)
When input is ``da.Array`` object, output is always ``da.Array``. However, if the input
Expand Down
1 change: 0 additions & 1 deletion doc/tutorials/saving_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ Will print out something similar to (not actual output as it's too long for demo
"gradient_booster": {
"gbtree_train_param": {
"num_parallel_tree": "1",
"predictor": "gpu_predictor",
"process_type": "default",
"tree_method": "gpu_hist",
"updater": "grow_gpu_hist",
Expand Down
7 changes: 6 additions & 1 deletion include/xgboost/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <dmlc/omp.h>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -112,7 +113,7 @@ using bst_row_t = std::size_t; // NOLINT
/*! \brief Type for tree node index. */
using bst_node_t = std::int32_t; // NOLINT
/*! \brief Type for ranking group index. */
using bst_group_t = std::uint32_t; // NOLINT
using bst_group_t = std::uint32_t; // NOLINT
/**
* \brief Type for indexing into output targets.
*/
Expand All @@ -125,6 +126,10 @@ using bst_layer_t = std::int32_t; // NOLINT
* \brief Type for indexing trees.
*/
using bst_tree_t = std::int32_t; // NOLINT
/**
* @brief Ordinal of a CUDA device.
*/
using bst_d_ordinal_t = std::int16_t; // NOLINT

namespace detail {
/*! \brief Implementation of gradient statistics pair. Template specialisation
Expand Down
12 changes: 12 additions & 0 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,9 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, DMatrixHandle dmat
/**
* \brief Inplace prediction from CPU dense matrix.
*
* \note If the booster is configured to run on a CUDA device, XGBoost falls back to run
* prediction with DMatrix with a performance warning.
*
* \param handle Booster handle.
* \param values JSON encoded __array_interface__ to values.
* \param config See \ref XGBoosterPredictFromDMatrix for more info.
Expand All @@ -1091,6 +1094,9 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
/**
* \brief Inplace prediction from CPU CSR matrix.
*
* \note If the booster is configured to run on a CUDA device, XGBoost falls back to run
* prediction with DMatrix with a performance warning.
*
* \param handle Booster handle.
* \param indptr JSON encoded __array_interface__ to row pointer in CSR.
* \param indices JSON encoded __array_interface__ to column indices in CSR.
Expand All @@ -1116,6 +1122,9 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, ch
/**
* \brief Inplace prediction from CUDA Dense matrix (cupy in Python).
*
* \note If the booster is configured to run on a CPU, XGBoost falls back to run
* prediction with DMatrix with a performance warning.
*
* \param handle Booster handle
* \param values JSON encoded __cuda_array_interface__ to values.
* \param config See \ref XGBoosterPredictFromDMatrix for more info.
Expand All @@ -1137,6 +1146,9 @@ XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *valu
/**
* \brief Inplace prediction from CUDA dense dataframe (cuDF in Python).
*
* \note If the booster is configured to run on a CPU, XGBoost falls back to run
* prediction with DMatrix with a performance warning.
*
* \param handle Booster handle
* \param values List of __cuda_array_interface__ for all columns encoded in JSON list.
* \param config See \ref XGBoosterPredictFromDMatrix for more info.
Expand Down
139 changes: 114 additions & 25 deletions include/xgboost/context.h
Original file line number Diff line number Diff line change
@@ -1,20 +1,79 @@
/*!
* Copyright 2014-2022 by Contributors
/**
* Copyright 2014-2023, XGBoost Contributors
* \file context.h
*/
#ifndef XGBOOST_CONTEXT_H_
#define XGBOOST_CONTEXT_H_

#include <xgboost/logging.h>
#include <xgboost/parameter.h>
#include <xgboost/base.h> // for bst_d_ordinal_t
#include <xgboost/logging.h> // for CHECK_GE
#include <xgboost/parameter.h> // for XGBoostParameter

#include <memory> // std::shared_ptr
#include <string>
#include <cstdint> // for int16_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <string> // for string, to_string

namespace xgboost {

struct CUDAContext;

/**
* @brief A type for device ordinal. The type is packed into 32-bit for efficient use in
* viewing types like `linalg::TensorView`.
*/
struct DeviceOrd {
enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU};
// CUDA device ordinal.
bst_d_ordinal_t ordinal{-1};

[[nodiscard]] bool IsCUDA() const { return device == kCUDA; }
[[nodiscard]] bool IsCPU() const { return device == kCPU; }

DeviceOrd() = default;
constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}

DeviceOrd(DeviceOrd const& that) = default;
DeviceOrd& operator=(DeviceOrd const& that) = default;
DeviceOrd(DeviceOrd&& that) = default;
DeviceOrd& operator=(DeviceOrd&& that) = default;

/**
* @brief Constructor for CPU.
*/
[[nodiscard]] constexpr static auto CPU() { return DeviceOrd{kCPU, -1}; }
/**
* @brief Constructor for CUDA device.
*
* @param ordinal CUDA device ordinal.
*/
[[nodiscard]] static auto CUDA(bst_d_ordinal_t ordinal) { return DeviceOrd{kCUDA, ordinal}; }

[[nodiscard]] bool operator==(DeviceOrd const& that) const {
return device == that.device && ordinal == that.ordinal;
}
[[nodiscard]] bool operator!=(DeviceOrd const& that) const { return !(*this == that); }
/**
* @brief Get a string representation of the device and the ordinal.
*/
[[nodiscard]] std::string Name() const {
switch (device) {
case DeviceOrd::kCPU:
return "CPU";
case DeviceOrd::kCUDA:
return "CUDA:" + std::to_string(ordinal);
default: {
LOG(FATAL) << "Unknown device.";
return "";
}
}
}
};

static_assert(sizeof(DeviceOrd) == sizeof(std::int32_t));

/**
* @brief Runtime context for XGBoost. Contains information like threads and device.
*/
struct Context : public XGBoostParameter<Context> {
public:
// Constant representing the device ID of CPU.
Expand All @@ -36,29 +95,59 @@ struct Context : public XGBoostParameter<Context> {
// fail when gpu_id is invalid
bool fail_on_invalid_gpu_id{false};
bool validate_parameters{false};

/*!
* \brief Configure the parameter `gpu_id'.
/**
* @brief Configure the parameter `gpu_id'.
*
* \param require_gpu Whether GPU is explicitly required from user.
* @param require_gpu Whether GPU is explicitly required by the user through other
* configurations.
*/
void ConfigureGpuId(bool require_gpu);
/*!
* Return automatically chosen threads.
/**
* @brief Returns the automatically chosen number of threads based on the `nthread`
* parameter and the system settting.
*/
std::int32_t Threads() const;

bool IsCPU() const { return gpu_id == kCpuId; }
bool IsCUDA() const { return !IsCPU(); }

CUDAContext const* CUDACtx() const;
// Make a CUDA context based on the current context.
Context MakeCUDA(std::int32_t device = 0) const {
[[nodiscard]] std::int32_t Threads() const;
/**
* @brief Is XGBoost running on CPU?
*/
[[nodiscard]] bool IsCPU() const { return gpu_id == kCpuId; }
/**
* @brief Is XGBoost running on a CUDA device?
*/
[[nodiscard]] bool IsCUDA() const { return !IsCPU(); }
/**
* @brief Get the current device and ordinal.
*/
[[nodiscard]] DeviceOrd Device() const {
return IsCPU() ? DeviceOrd::CPU() : DeviceOrd::CUDA(static_cast<bst_d_ordinal_t>(gpu_id));
}
/**
* @brief Get the CUDA device ordinal. -1 if XGBoost is running on CPU.
*/
[[nodiscard]] bst_d_ordinal_t Ordinal() const { return this->gpu_id; }
/**
* @brief Name of the current device.
*/
[[nodiscard]] std::string DeviceName() const { return Device().Name(); }
/**
* @brief Get a CUDA device context for allocator and stream.
*/
[[nodiscard]] CUDAContext const* CUDACtx() const;
/**
* @brief Make a CUDA context based on the current context.
*
* @param ordinal The CUDA device ordinal.
*/
[[nodiscard]] Context MakeCUDA(std::int32_t ordinal = 0) const {
Context ctx = *this;
ctx.gpu_id = device;
CHECK_GE(ordinal, 0);
ctx.gpu_id = ordinal;
return ctx;
}
Context MakeCPU() const {
/**
* @brief Make a CPU context based on the current context.
*/
[[nodiscard]] Context MakeCPU() const {
Context ctx = *this;
ctx.gpu_id = kCpuId;
return ctx;
Expand Down Expand Up @@ -87,9 +176,9 @@ struct Context : public XGBoostParameter<Context> {
}

private:
// mutable for lazy initialization for cuda context to avoid initializing CUDA at load.
// shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define p_impl
// while trying to hide CUDA code from host compiler.
// mutable for lazy cuda context initialization. This avoids initializing CUDA at load.
// shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define
// p_impl while trying to hide CUDA code from the host compiler.
mutable std::shared_ptr<CUDAContext> cuctx_;
// cached value for CFS CPU limit. (used in containerized env)
std::int32_t cfs_cpu_count_; // NOLINT
Expand Down
18 changes: 7 additions & 11 deletions include/xgboost/gbm.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,14 @@ class GradientBooster : public Model, public Configurable {
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \param approximate use a faster (inconsistent) approximation of SHAP values
* \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on).
* \param condition_feature feature to condition on (i.e. fix) during calculations
*/
virtual void PredictContribution(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned layer_begin, unsigned layer_end,
bool approximate = false, int condition = 0,
unsigned condition_feature = 0) = 0;

virtual void PredictInteractionContributions(
DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
unsigned layer_begin, unsigned layer_end, bool approximate) = 0;
virtual void PredictContribution(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate = false) = 0;

virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate) = 0;

/*!
* \brief dump the model in the requested format
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ public void testBooster() throws XGBoostError {
put("num_round", round);
put("num_workers", 1);
put("tree_method", "gpu_hist");
put("predictor", "gpu_predictor");
put("max_bin", maxBin);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ object GpuPreXGBoost extends PreXGBoostProvider {
// - predictor: Force to gpu predictor since native doesn't save predictor.
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
booster.setParam("gpu_id", gpuId.toString)
booster.setParam("predictor", "gpu_predictor")
logger.info("GPU transform on device: " + gpuId)
boosterFlag.isGpuParamsSet = true;
}
Expand Down
Loading

0 comments on commit 39390cc

Please sign in to comment.