Commit 8c676c8

Remove internal use of gpu_id. (#9568)

trivialfis authored Sep 20, 2023
1 parent 38ac52d commit 8c676c8

Showing 121 changed files with 1,014 additions and 1,046 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python_tests.yml
@@ -190,7 +190,7 @@ jobs:
       run: |
         mkdir build_msvc
         cd build_msvc
-        cmake .. -G"Visual Studio 17 2022" -DCMAKE_CONFIGURATION_TYPES="Release" -A x64 -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON -DBUILD_DEPRECATED_CLI=ON
+        cmake .. -G"Visual Studio 17 2022" -DCMAKE_CONFIGURATION_TYPES="Release" -A x64 -DBUILD_DEPRECATED_CLI=ON
         cmake --build . --config Release --parallel $(nproc)
     - name: Install Python package
41 changes: 23 additions & 18 deletions include/xgboost/context.h
@@ -29,31 +29,37 @@ struct DeviceSym {
  * viewing types like `linalg::TensorView`.
  */
 struct DeviceOrd {
+  // Constant representing the device ID of CPU.
+  static bst_d_ordinal_t constexpr CPUOrdinal() { return -1; }
+  static bst_d_ordinal_t constexpr InvalidOrdinal() { return -2; }
+
   enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU};
   // CUDA device ordinal.
-  bst_d_ordinal_t ordinal{-1};
+  bst_d_ordinal_t ordinal{CPUOrdinal()};
 
   [[nodiscard]] bool IsCUDA() const { return device == kCUDA; }
   [[nodiscard]] bool IsCPU() const { return device == kCPU; }
 
-  DeviceOrd() = default;
+  constexpr DeviceOrd() = default;
   constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}
 
-  DeviceOrd(DeviceOrd const& that) = default;
-  DeviceOrd& operator=(DeviceOrd const& that) = default;
-  DeviceOrd(DeviceOrd&& that) = default;
-  DeviceOrd& operator=(DeviceOrd&& that) = default;
+  constexpr DeviceOrd(DeviceOrd const& that) = default;
+  constexpr DeviceOrd& operator=(DeviceOrd const& that) = default;
+  constexpr DeviceOrd(DeviceOrd&& that) = default;
+  constexpr DeviceOrd& operator=(DeviceOrd&& that) = default;
 
   /**
    * @brief Constructor for CPU.
    */
-  [[nodiscard]] constexpr static auto CPU() { return DeviceOrd{kCPU, -1}; }
+  [[nodiscard]] constexpr static auto CPU() { return DeviceOrd{kCPU, CPUOrdinal()}; }
   /**
    * @brief Constructor for CUDA device.
    *
    * @param ordinal CUDA device ordinal.
    */
-  [[nodiscard]] static auto CUDA(bst_d_ordinal_t ordinal) { return DeviceOrd{kCUDA, ordinal}; }
+  [[nodiscard]] static constexpr auto CUDA(bst_d_ordinal_t ordinal) {
+    return DeviceOrd{kCUDA, ordinal};
+  }
 
   [[nodiscard]] bool operator==(DeviceOrd const& that) const {
     return device == that.device && ordinal == that.ordinal;
@@ -78,33 +84,33 @@ struct DeviceOrd {
 
 static_assert(sizeof(DeviceOrd) == sizeof(std::int32_t));
 
+std::ostream& operator<<(std::ostream& os, DeviceOrd ord);
+
 /**
  * @brief Runtime context for XGBoost. Contains information like threads and device.
  */
 struct Context : public XGBoostParameter<Context> {
  private:
+  // User interfacing parameter for device ordinal
   std::string device{DeviceSym::CPU()};  // NOLINT
-  // The device object for the current context. We are in the middle of replacing the
-  // `gpu_id` with this device field.
+  // The device ordinal set by user
   DeviceOrd device_{DeviceOrd::CPU()};
 
  public:
-  // Constant representing the device ID of CPU.
-  static bst_d_ordinal_t constexpr kCpuId = -1;
+  static bst_d_ordinal_t constexpr InvalidOrdinal() { return -2; }
   static std::int64_t constexpr kDefaultSeed = 0;
 
  public:
   Context();
 
   void Init(Args const& kwargs);
 
   template <typename Container>
   Args UpdateAllowUnknown(Container const& kwargs) {
     auto args = XGBoostParameter<Context>::UpdateAllowUnknown(kwargs);
     this->SetDeviceOrdinal(kwargs);
     return args;
   }
 
-  std::int32_t gpu_id{kCpuId};
   // The number of threads to use if OpenMP is enabled. If equals 0, use the system default.
   std::int32_t nthread{0};  // NOLINT
   // stored random seed
@@ -116,7 +122,8 @@ struct Context : public XGBoostParameter<Context> {
   bool validate_parameters{false};
 
   /**
-   * @brief Configure the parameter `gpu_id'.
+   * @brief Configure the parameter `device'. Deprecated, will remove once `gpu_id` is
+   * removed.
    *
    * @param require_gpu Whether GPU is explicitly required by the user through other
    * configurations.
@@ -212,9 +219,7 @@ struct Context : public XGBoostParameter<Context> {
  private:
   void SetDeviceOrdinal(Args const& kwargs);
   Context& SetDevice(DeviceOrd d) {
-    this->device_ = d;
-    this->gpu_id = d.ordinal;  // this can be removed once we move away from `gpu_id`.
-    this->device = d.Name();
+    this->device = (this->device_ = d).Name();
     return *this;
   }
 
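To make the new surface concrete, here is a minimal usage sketch. The `main` harness is illustrative, not part of the commit; `DeviceOrd::CPU()`, `DeviceOrd::CUDA()`, `CPUOrdinal()`, `IsCPU()`, `IsCUDA()`, and the `constexpr` constructors are taken from the diff above.

```cpp
#include <iostream>

#include "xgboost/context.h"  // xgboost::DeviceOrd

int main() {
  // Before this commit, device selection leaked through a raw integer
  // (`gpu_id`): -1 meant CPU, 0 and up meant a CUDA ordinal. DeviceOrd
  // carries the device type and the ordinal together.
  auto cpu = xgboost::DeviceOrd::CPU();
  auto gpu = xgboost::DeviceOrd::CUDA(0);
  std::cout << cpu.IsCPU() << " " << gpu.IsCUDA() << "\n";  // prints: 1 1

  // The constructors are constexpr after this change, so device
  // constants can be formed at compile time.
  static_assert(xgboost::DeviceOrd::CUDA(3).ordinal == 3);
  static_assert(xgboost::DeviceOrd::CPU().ordinal == xgboost::DeviceOrd::CPUOrdinal());
  return 0;
}
```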
6 changes: 3 additions & 3 deletions include/xgboost/data.h
@@ -106,10 +106,10 @@ class MetaInfo {
   MetaInfo& operator=(MetaInfo&& that) = default;
   MetaInfo& operator=(MetaInfo const& that) = delete;
 
-  /*!
-   * \brief Validate all metainfo.
+  /**
+   * @brief Validate all metainfo.
    */
-  void Validate(int32_t device) const;
+  void Validate(DeviceOrd device) const;
 
   MetaInfo Slice(common::Span<int32_t const> ridxs) const;
 
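A brief call-site sketch of the `Validate` change; the wrapper function and the ordinal `0` below are assumptions for illustration, not code from this commit.

```cpp
#include "xgboost/context.h"  // xgboost::DeviceOrd
#include "xgboost/data.h"     // xgboost::MetaInfo

// Old call sites passed a raw ordinal (info.Validate(0) or info.Validate(-1));
// the CPU/GPU distinction was implicit in the sign.
void ValidateFor(xgboost::MetaInfo const& info, bool on_gpu) {
  // After this commit the device is named explicitly at the call site.
  info.Validate(on_gpu ? xgboost::DeviceOrd::CUDA(0) : xgboost::DeviceOrd::CPU());
}
```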
31 changes: 11 additions & 20 deletions include/xgboost/host_device_vector.h
@@ -88,9 +88,9 @@ class HostDeviceVector {
   static_assert(std::is_standard_layout<T>::value, "HostDeviceVector admits only POD types");
 
  public:
-  explicit HostDeviceVector(size_t size = 0, T v = T(), int device = -1);
-  HostDeviceVector(std::initializer_list<T> init, int device = -1);
-  explicit HostDeviceVector(const std::vector<T>& init, int device = -1);
+  explicit HostDeviceVector(size_t size = 0, T v = T(), DeviceOrd device = DeviceOrd::CPU());
+  HostDeviceVector(std::initializer_list<T> init, DeviceOrd device = DeviceOrd::CPU());
+  explicit HostDeviceVector(const std::vector<T>& init, DeviceOrd device = DeviceOrd::CPU());
   ~HostDeviceVector();
 
   HostDeviceVector(const HostDeviceVector<T>&) = delete;
@@ -99,17 +99,9 @@ class HostDeviceVector {
   HostDeviceVector<T>& operator=(const HostDeviceVector<T>&) = delete;
   HostDeviceVector<T>& operator=(HostDeviceVector<T>&&);
 
-  bool Empty() const { return Size() == 0; }
-  size_t Size() const;
-  int DeviceIdx() const;
-  DeviceOrd Device() const {
-    auto idx = this->DeviceIdx();
-    if (idx == DeviceOrd::CPU().ordinal) {
-      return DeviceOrd::CPU();
-    } else {
-      return DeviceOrd::CUDA(idx);
-    }
-  }
+  [[nodiscard]] bool Empty() const { return Size() == 0; }
+  [[nodiscard]] std::size_t Size() const;
+  [[nodiscard]] DeviceOrd Device() const;
   common::Span<T> DeviceSpan();
   common::Span<const T> ConstDeviceSpan() const;
   common::Span<const T> DeviceSpan() const { return ConstDeviceSpan(); }
@@ -135,13 +127,12 @@ class HostDeviceVector {
   const std::vector<T>& ConstHostVector() const;
   const std::vector<T>& HostVector() const {return ConstHostVector(); }
 
-  bool HostCanRead() const;
-  bool HostCanWrite() const;
-  bool DeviceCanRead() const;
-  bool DeviceCanWrite() const;
-  GPUAccess DeviceAccess() const;
+  [[nodiscard]] bool HostCanRead() const;
+  [[nodiscard]] bool HostCanWrite() const;
+  [[nodiscard]] bool DeviceCanRead() const;
+  [[nodiscard]] bool DeviceCanWrite() const;
+  [[nodiscard]] GPUAccess DeviceAccess() const;
 
-  void SetDevice(int device) const;
+  void SetDevice(DeviceOrd device) const;
 
   void Resize(size_t new_size, T v = T());
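A sketch of how the `HostDeviceVector` constructors and `SetDevice` read after the change. The function below is illustrative only, and the CUDA lines assume a CUDA-enabled build with a visible device 0.

```cpp
#include "xgboost/context.h"             // xgboost::DeviceOrd
#include "xgboost/host_device_vector.h"  // xgboost::HostDeviceVector

void Sketch() {
  using xgboost::DeviceOrd;
  using xgboost::HostDeviceVector;

  // The default device argument is now DeviceOrd::CPU() instead of -1.
  HostDeviceVector<float> host_only{1.f, 2.f, 3.f};
  bool on_cpu = host_only.Device().IsCPU();  // true
  (void)on_cpu;

  // Targeting a CUDA device names the device type explicitly.
  HostDeviceVector<float> on_gpu{{1.f, 2.f, 3.f}, DeviceOrd::CUDA(0)};
  on_gpu.SetDevice(DeviceOrd::CUDA(0));  // same ordinal; shows the new signature
}
```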
39 changes: 15 additions & 24 deletions include/xgboost/linalg.h
@@ -659,13 +659,13 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
 
 template <typename T>
 auto MakeVec(HostDeviceVector<T> *data) {
-  return MakeVec(data->DeviceIdx() == -1 ? data->HostPointer() : data->DevicePointer(),
-                 data->Size(), data->Device());
+  return MakeVec(data->Device().IsCPU() ? data->HostPointer() : data->DevicePointer(), data->Size(),
+                 data->Device());
 }
 
 template <typename T>
 auto MakeVec(HostDeviceVector<T> const *data) {
-  return MakeVec(data->DeviceIdx() == -1 ? data->ConstHostPointer() : data->ConstDevicePointer(),
+  return MakeVec(data->Device().IsCPU() ? data->ConstHostPointer() : data->ConstDevicePointer(),
                  data->Size(), data->Device());
 }
 
@@ -757,13 +757,13 @@ class Tensor {
   Order order_{Order::kC};
 
   template <typename I, std::int32_t D>
-  void Initialize(I const (&shape)[D], std::int32_t device) {
+  void Initialize(I const (&shape)[D], DeviceOrd device) {
     static_assert(D <= kDim, "Invalid shape.");
     std::copy(shape, shape + D, shape_);
     for (auto i = D; i < kDim; ++i) {
       shape_[i] = 1;
     }
-    if (device >= 0) {
+    if (device.IsCUDA()) {
       data_.SetDevice(device);
       data_.ConstDevicePointer();  // Pull to device;
     }
@@ -780,34 +780,31 @@ class Tensor {
    * See \ref TensorView for parameters of this constructor.
    */
   template <typename I, int32_t D>
-  explicit Tensor(I const (&shape)[D], std::int32_t device, Order order = kC)
-      : Tensor{common::Span<I const, D>{shape}, device, order} {}
-  template <typename I, int32_t D>
   explicit Tensor(I const (&shape)[D], DeviceOrd device, Order order = kC)
-      : Tensor{common::Span<I const, D>{shape}, device.ordinal, order} {}
+      : Tensor{common::Span<I const, D>{shape}, device, order} {}
 
   template <typename I, size_t D>
-  explicit Tensor(common::Span<I const, D> shape, std::int32_t device, Order order = kC)
+  explicit Tensor(common::Span<I const, D> shape, DeviceOrd device, Order order = kC)
       : order_{order} {
     // No device unroll as this is a host only function.
     std::copy(shape.data(), shape.data() + D, shape_);
     for (auto i = D; i < kDim; ++i) {
       shape_[i] = 1;
     }
     auto size = detail::CalcSize(shape_);
-    if (device >= 0) {
+    if (device.IsCUDA()) {
       data_.SetDevice(device);
     }
     data_.Resize(size);
-    if (device >= 0) {
+    if (device.IsCUDA()) {
       data_.DevicePointer();  // Pull to device
     }
   }
   /**
    * Initialize from 2 host iterators.
    */
   template <typename It, typename I, int32_t D>
-  explicit Tensor(It begin, It end, I const (&shape)[D], std::int32_t device, Order order = kC)
+  explicit Tensor(It begin, It end, I const (&shape)[D], DeviceOrd device, Order order = kC)
       : order_{order} {
     auto &h_vec = data_.HostVector();
     h_vec.insert(h_vec.begin(), begin, end);
@@ -816,18 +813,14 @@ class Tensor {
   }
 
   template <typename I, int32_t D>
-  explicit Tensor(std::initializer_list<T> data, I const (&shape)[D], std::int32_t device,
+  explicit Tensor(std::initializer_list<T> data, I const (&shape)[D], DeviceOrd device,
                   Order order = kC)
       : order_{order} {
     auto &h_vec = data_.HostVector();
     h_vec = data;
     // shape
     this->Initialize(shape, device);
   }
-  template <typename I, int32_t D>
-  explicit Tensor(std::initializer_list<T> data, I const (&shape)[D], DeviceOrd device,
-                  Order order = kC)
-      : Tensor{data, shape, device.ordinal, order} {}
   /**
    * \brief Index operator. Not thread safe, should not be used in performance critical
    * region. For more efficient indexing, consider getting a view first.
@@ -944,9 +937,7 @@ class Tensor {
   /**
    * \brief Set device ordinal for this tensor.
    */
-  void SetDevice(int32_t device) const { data_.SetDevice(device); }
   void SetDevice(DeviceOrd device) const { data_.SetDevice(device); }
-  [[nodiscard]] int32_t DeviceIdx() const { return data_.DeviceIdx(); }
+  [[nodiscard]] DeviceOrd Device() const { return data_.Device(); }
 };
 
@@ -962,7 +953,7 @@ using Vector = Tensor<T, 1>;
 template <typename T, typename... Index>
 auto Empty(Context const *ctx, Index &&...index) {
   Tensor<T, sizeof...(Index)> t;
-  t.SetDevice(ctx->gpu_id);
+  t.SetDevice(ctx->Device());
   t.Reshape(index...);
   return t;
 }
@@ -973,7 +964,7 @@ auto Empty(Context const *ctx, Index &&...index) {
 template <typename T, typename... Index>
 auto Constant(Context const *ctx, T v, Index &&...index) {
   Tensor<T, sizeof...(Index)> t;
-  t.SetDevice(ctx->gpu_id);
+  t.SetDevice(ctx->Device());
   t.Reshape(index...);
   t.Data()->Fill(std::move(v));
   return t;
@@ -990,8 +981,8 @@ auto Zeros(Context const *ctx, Index &&...index) {
 // Only first axis is supported for now.
 template <typename T, int32_t D>
 void Stack(Tensor<T, D> *l, Tensor<T, D> const &r) {
-  if (r.DeviceIdx() >= 0) {
-    l->SetDevice(r.DeviceIdx());
+  if (r.Device().IsCUDA()) {
+    l->SetDevice(r.Device());
   }
   l->ModifyInplace([&](HostDeviceVector<T> *data, common::Span<size_t, D> shape) {
     for (size_t i = 1; i < D; ++i) {
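How the linalg factories look from a call site after the migration. The helper below is hypothetical, but `linalg::Zeros`, `linalg::Constant`, the `Tensor` constructor taking `DeviceOrd`, and `ctx->Device()` all appear in the diff above.

```cpp
#include "xgboost/context.h"  // xgboost::Context
#include "xgboost/linalg.h"   // xgboost::linalg::{Tensor, Zeros, Constant}

namespace xgboost {
// Hypothetical helper: allocate buffers on whatever device the context
// is configured for, without touching a raw ordinal anywhere.
void AllocateBuffers(Context const* ctx) {
  auto grad = linalg::Zeros<float>(ctx, 16, 2);       // 16x2 tensor, zero-filled
  auto hess = linalg::Constant<float>(ctx, 1.f, 16);  // length-16 vector of ones

  // Tensor constructors accept DeviceOrd directly as well.
  linalg::Tensor<float, 2> scratch{{16, 2}, ctx->Device()};
  (void)grad;
  (void)hess;
  (void)scratch;
}
}  // namespace xgboost
```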
4 changes: 2 additions & 2 deletions include/xgboost/predictor.h
@@ -52,9 +52,9 @@ class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
 
  public:
   PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
-  PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, std::int32_t device) {
+  PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, DeviceOrd device) {
     auto p_cache = this->CacheItem(m);
-    if (device != Context::kCpuId) {
+    if (device.IsCUDA()) {
       p_cache->predictions.SetDevice(device);
     }
     return *p_cache;
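A hypothetical call site for the updated `Cache` signature; the surrounding function and its arguments are assumptions, only the second argument's type changed in this commit.

```cpp
#include <memory>

#include "xgboost/context.h"    // xgboost::Context
#include "xgboost/data.h"       // xgboost::DMatrix
#include "xgboost/predictor.h"  // xgboost::PredictionContainer

xgboost::PredictionCacheEntry& FetchCache(xgboost::PredictionContainer* container,
                                          std::shared_ptr<xgboost::DMatrix> p_fmat,
                                          xgboost::Context const* ctx) {
  // Previously: container->Cache(p_fmat, ctx->gpu_id);
  return container->Cache(p_fmat, ctx->Device());
}
```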
4 changes: 2 additions & 2 deletions src/c_api/c_api.cu
@@ -66,7 +66,7 @@ void CopyGradientFromCUDAArrays(Context const *ctx, ArrayInterface<2, false> con
   auto hess_dev = dh::CudaGetPointerDevice(hess.data);
   CHECK_EQ(grad_dev, hess_dev) << "gradient and hessian should be on the same device.";
   auto &gpair = *out_gpair;
-  gpair.SetDevice(grad_dev);
+  gpair.SetDevice(DeviceOrd::CUDA(grad_dev));
   gpair.Reshape(grad.Shape(0), grad.Shape(1));
   auto d_gpair = gpair.View(DeviceOrd::CUDA(grad_dev));
   auto cuctx = ctx->CUDACtx();
@@ -144,7 +144,7 @@ int InplacePreidctCUDA(BoosterHandle handle, char const *c_array_interface,
   if (learner->Ctx()->IsCUDA()) {
     CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead());
   }
-  p_predt->SetDevice(proxy->DeviceIdx());
+  p_predt->SetDevice(proxy->Device());
 
   auto &shape = learner->GetThreadLocal().prediction_shape;
   size_t n_samples = p_m->Info().num_row_;
10 changes: 4 additions & 6 deletions src/collective/aggregator.cuh
@@ -15,8 +15,7 @@
 
 #include "communicator-inl.cuh"
 
-namespace xgboost {
-namespace collective {
+namespace xgboost::collective {
 
 /**
  * @brief Find the global sum of the given values across all workers.
@@ -31,10 +30,9 @@
  * @param size Number of values to sum.
  */
 template <typename T>
-void GlobalSum(MetaInfo const& info, int device, T* values, size_t size) {
+void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) {
   if (info.IsRowSplit()) {
-    collective::AllReduce<collective::Operation::kSum>(device, values, size);
+    collective::AllReduce<collective::Operation::kSum>(device.ordinal, values, size);
   }
 }
-}  // namespace collective
-}  // namespace xgboost
+}  // namespace xgboost::collective
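A hypothetical caller of the reworked `GlobalSum`, assuming a device-resident buffer `d_sums` of length `n`; after this change the `.ordinal` unwrap happens only inside `GlobalSum`, at the communicator boundary.

```cpp
#include <cstddef>

#include "xgboost/context.h"  // xgboost::Context
#include "xgboost/data.h"     // xgboost::MetaInfo

namespace xgboost::collective {
// Illustrative wrapper: sum partial results across workers on the
// context's device without handling raw ordinals at the call site.
template <typename T>
void SumAcrossWorkers(Context const* ctx, MetaInfo const& info, T* d_sums, std::size_t n) {
  GlobalSum(info, ctx->Device(), d_sums, n);
}
}  // namespace xgboost::collective
```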