diff --git a/src/common/device_vector.cuh b/src/common/device_vector.cuh index b2065d3330ba..004f0881de3c 100644 --- a/src/common/device_vector.cuh +++ b/src/common/device_vector.cuh @@ -30,6 +30,7 @@ #include // for CurrentDevice #include // for map #include // for unique_ptr +#include // for defer_lock #include "common.h" // for safe_cuda, HumanMemUnit #include "xgboost/logging.h" @@ -46,6 +47,12 @@ class MemoryLogger { size_t num_deallocations{0}; std::map device_allocations; void RegisterAllocation(void *ptr, size_t n) { + auto itr = device_allocations.find(ptr); + if (itr != device_allocations.cend()) { + LOG(WARNING) << "Attempting to allocate " << n << " bytes." + << " that was already allocated\nptr:" << ptr << "\n" + << dmlc::StackTrace(); + } device_allocations[ptr] = n; currently_allocated_bytes += n; peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes); @@ -56,7 +63,7 @@ class MemoryLogger { auto itr = device_allocations.find(ptr); if (itr == device_allocations.end()) { LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device - << " that was never allocated\n" + << " that was never allocated\nptr:" << ptr << "\n" << dmlc::StackTrace(); } else { num_deallocations++; @@ -70,18 +77,34 @@ class MemoryLogger { std::mutex mutex_; public: - void RegisterAllocation(void *ptr, size_t n) { + /** + * @brief Register the allocation for logging. + * + * @param lock Set to false if the allocator has locking machanism. + */ + void RegisterAllocation(void *ptr, size_t n, bool lock) { if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { return; } - std::lock_guard guard(mutex_); + std::unique_lock guard{mutex_, std::defer_lock}; + if (lock) { + guard.lock(); + } stats_.RegisterAllocation(ptr, n); } - void RegisterDeallocation(void *ptr, size_t n) { + /** + * @brief Register the deallocation for logging. + * + * @param lock Set to false if the allocator has locking machanism. + */ + void RegisterDeallocation(void *ptr, size_t n, bool lock) { if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { return; } - std::lock_guard guard(mutex_); + std::unique_lock guard{mutex_, std::defer_lock}; + if (lock) { + guard.lock(); + } stats_.RegisterDeallocation(ptr, n, cub::CurrentDevice()); } size_t PeakMemory() const { return stats_.peak_allocated_bytes; } @@ -140,11 +163,12 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator { } catch (const std::exception &e) { detail::ThrowOOMError(e.what(), n * sizeof(T)); } - GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T)); + // We can't place a lock here as template allocator is transient. + GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T), true); return ptr; } void deallocate(pointer ptr, size_t n) { // NOLINT - GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T)); + GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true); SuperT::deallocate(ptr, n); } #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 @@ -193,11 +217,12 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator { detail::ThrowOOMError(e.what(), n * sizeof(T)); } } - GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T)); + // We can't place a lock here as template allocator is transient. 
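A minimal, self-contained sketch of the optional-locking pattern used by RegisterAllocation/RegisterDeallocation above; all names here are illustrative, not part of this patch. The guard is created with std::defer_lock and only acquired when the caller does not already serialize access around the allocation and the registration.

#include <cstddef>  // for size_t
#include <mutex>    // for mutex, unique_lock, lock_guard, defer_lock

class UsageLogger {
  std::mutex mutex_;
  std::size_t bytes_{0};

 public:
  // `lock == false` means the caller already serializes calls into the logger
  // (for example by holding its own mutex), so taking this mutex again would
  // be redundant.
  void Register(std::size_t n, bool lock) {
    std::unique_lock<std::mutex> guard{mutex_, std::defer_lock};
    if (lock) {
      guard.lock();
    }
    bytes_ += n;
  }
};

int main() {
  UsageLogger logger;
  logger.Register(128, /*lock=*/true);  // transient allocator: no lock of its own
  std::mutex resource_lock;             // a resource that locks around allocation
  {
    std::lock_guard<std::mutex> guard{resource_lock};
    logger.Register(256, /*lock=*/false);  // already serialized by the resource
  }
  return 0;
}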
+ GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T), true); return thrust_ptr; } void deallocate(pointer ptr, size_t n) { // NOLINT - GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T)); + GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true); if (use_cub_allocator_) { GetGlobalCachingAllocator().DeviceFree(ptr.get()); } else { @@ -239,14 +264,15 @@ using caching_device_vector = thrust::device_vector guard{lock_, std::defer_lock}; + if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { + guard.lock(); + } try { auto const ptr = mr_->allocate(bytes, stream); - GlobalMemoryLogger().RegisterAllocation(ptr, bytes); + GlobalMemoryLogger().RegisterAllocation(ptr, bytes, false); return ptr; } catch (rmm::bad_alloc const &e) { detail::ThrowOOMError(e.what(), bytes); @@ -268,8 +298,12 @@ class LoggingResource : public rmm::mr::device_memory_resource { void do_deallocate(void *ptr, std::size_t bytes, // NOLINT rmm::cuda_stream_view stream) override { + std::unique_lock guard{lock_, std::defer_lock}; + if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { + guard.lock(); + } mr_->deallocate(ptr, bytes, stream); - GlobalMemoryLogger().RegisterDeallocation(ptr, bytes); + GlobalMemoryLogger().RegisterDeallocation(ptr, bytes, false); } [[nodiscard]] bool do_is_equal( // NOLINT diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index f81e2116c5df..4b3a3cae644f 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -1,5 +1,5 @@ /** - * Copyright 2018~2023 by XGBoost contributors + * Copyright 2018~2024, XGBoost contributors */ #include #include @@ -32,13 +32,12 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) { double eps = 1.0 / (WQSketch::kFactor * max_bins); size_t dummy_nlevel; size_t num_cuts; - WQuantileSketch::LimitSizeLevel( - num_rows, eps, &dummy_nlevel, &num_cuts); + WQuantileSketch::LimitSizeLevel(num_rows, eps, &dummy_nlevel, &num_cuts); return std::min(num_cuts, num_rows); } -size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, - size_t max_bins, size_t nnz) { +size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, size_t max_bins, + bst_idx_t nnz) { auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows); auto if_dense = num_columns * per_column; auto result = std::min(nnz, if_dense); @@ -83,23 +82,31 @@ size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz, return peak; } -size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_rows, - bst_feature_t columns, size_t nnz, int device, size_t num_cuts, - bool has_weight) { +bst_idx_t SketchBatchNumElements(bst_idx_t sketch_batch_num_elements, SketchShape shape, int device, + size_t num_cuts, bool has_weight, std::size_t container_bytes) { auto constexpr kIntMax = static_cast(std::numeric_limits::max()); #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 - // device available memory is not accurate when rmm is used. - return std::min(nnz, kIntMax); + // Device available memory is not accurate when rmm is used. 
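For intuition on the total-memory-based cap computed in the RMM branch that follows, a rough worked example with purely illustrative numbers. The 1/16 factor mirrors the code below; the "about a quarter of device memory" reading is an assumption based on roughly 16 bytes of working memory per unweighted element, as in BytesPerElement.

#include <algorithm>  // for max
#include <cstdio>

int main() {
  // Illustrative numbers only: a 16 GiB device where the sketch container
  // already holds 1 GiB.
  double total_mem = 16.0 * (1ULL << 30) - 1.0 * (1ULL << 30);  // usable bytes, ~15 GiB
  double total_f32 = total_mem / sizeof(float);                 // ~4.0e9 floats
  double n_max_used_f32 = std::max(total_f32 / 16.0, 1.0);      // ~2.5e8 window elements
  // At ~16 bytes of working memory per unweighted element, this caps the
  // sliding-window buffers at about a quarter of device memory; the final
  // window length is further limited by nnz.
  std::printf("window cap: %.0f elements\n", n_max_used_f32);
  return 0;
}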
+ double total_mem = dh::TotalMemory(device) - container_bytes; + double total_f32 = total_mem / sizeof(float); + double n_max_used_f32 = std::max(total_f32 / 16.0, 1.0); // a quarter + if (shape.nnz > shape.Size()) { + // Unknown nnz + shape.nnz = shape.Size(); + } + return std::min(static_cast(n_max_used_f32), shape.nnz); #endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 + (void)container_bytes; // We known the remaining size when RMM is not used. - if (sketch_batch_num_elements == 0) { - auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight); + if (sketch_batch_num_elements == detail::UnknownSketchNumElements()) { + auto required_memory = + RequiredMemory(shape.n_samples, shape.n_features, shape.nnz, num_cuts, has_weight); // use up to 80% of available space auto avail = dh::AvailableMemory(device) * 0.8; if (required_memory > avail) { sketch_batch_num_elements = avail / BytesPerElement(has_weight); } else { - sketch_batch_num_elements = std::min(num_rows * static_cast(columns), nnz); + sketch_batch_num_elements = std::min(shape.Size(), shape.nnz); } } @@ -338,8 +345,9 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b // Configure batch size based on available memory std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_); sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, info.num_row_, info.num_col_, info.num_nonzero_, ctx->Ordinal(), - num_cuts_per_feature, has_weight); + sketch_batch_num_elements, + detail::SketchShape{info.num_row_, info.num_col_, info.num_nonzero_}, ctx->Ordinal(), + num_cuts_per_feature, has_weight, 0); CUDAContext const* cuctx = ctx->CUDACtx(); diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 416a0be9e8f6..47506805353b 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -10,7 +10,10 @@ #include #include // for sort -#include // for size_t +#include // for max +#include // for size_t +#include // for uint32_t +#include // for numeric_limits #include "../data/adapter.h" // for IsValidFunctor #include "algorithm.cuh" // for CopyIf @@ -186,13 +189,24 @@ inline size_t constexpr BytesPerElement(bool has_weight) { return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2; } -/* \brief Calcuate the length of sliding window. Returns `sketch_batch_num_elements` +struct SketchShape { + bst_idx_t n_samples; + bst_feature_t n_features; + bst_idx_t nnz; + + template >* = nullptr> + SketchShape(bst_idx_t n_samples, F n_features, bst_idx_t nnz) + : n_samples{n_samples}, n_features{static_cast(n_features)}, nnz{nnz} {} + + [[nodiscard]] bst_idx_t Size() const { return n_samples * n_features; } +}; + +/** + * @brief Calcuate the length of sliding window. Returns `sketch_batch_num_elements` * directly if it's not 0. 
*/ -size_t SketchBatchNumElements(size_t sketch_batch_num_elements, - bst_idx_t num_rows, bst_feature_t columns, - size_t nnz, int device, - size_t num_cuts, bool has_weight); +bst_idx_t SketchBatchNumElements(bst_idx_t sketch_batch_num_elements, SketchShape shape, int device, + size_t num_cuts, bool has_weight, std::size_t container_bytes); // Compute number of sample cuts needed on local node to maintain accuracy // We take more cuts than needed and then reduce them later @@ -249,6 +263,8 @@ void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info, dh::device_vector* p_sorted_entries, dh::device_vector* p_sorted_weights, dh::caching_device_vector* p_column_sizes_scan); + +constexpr bst_idx_t UnknownSketchNumElements() { return 0; } } // namespace detail /** @@ -264,7 +280,7 @@ void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info, */ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin, Span hessian, - std::size_t sketch_batch_num_elements = 0); + std::size_t sketch_batch_num_elements = detail::UnknownSketchNumElements()); /** * @brief Compute sketch on DMatrix with GPU. @@ -276,14 +292,15 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b * * @return Quantile cuts */ -inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin, - std::size_t sketch_batch_num_elements = 0) { +inline HistogramCuts DeviceSketch( + Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin, + std::size_t sketch_batch_num_elements = detail::UnknownSketchNumElements()) { return DeviceSketchWithHessian(ctx, p_fmat, max_bin, {}, sketch_batch_num_elements); } template void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInfo const& info, - size_t columns, size_t begin, size_t end, float missing, + size_t n_features, size_t begin, size_t end, float missing, SketchContainer* sketch_container, int num_cuts) { // Copy current subset of valid elements into temporary storage and sort dh::device_vector sorted_entries; @@ -294,8 +311,9 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf HostDeviceVector cuts_ptr; cuts_ptr.SetDevice(ctx->Device()); CUDAContext const* cuctx = ctx->CUDACtx(); - detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns, num_cuts, - ctx->Device(), &cuts_ptr, &column_sizes_scan, &sorted_entries); + detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, n_features, + num_cuts, ctx->Device(), &cuts_ptr, &column_sizes_scan, + &sorted_entries); thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp()); if (sketch_container->HasCategorical()) { @@ -305,10 +323,11 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf } auto d_cuts_ptr = cuts_ptr.DeviceSpan(); - auto const &h_cuts_ptr = cuts_ptr.HostVector(); + auto const& h_cuts_ptr = cuts_ptr.HostVector(); // Extract the cuts from all columns concurrently sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr, h_cuts_ptr.back()); + sorted_entries.clear(); sorted_entries.shrink_to_fit(); } @@ -316,10 +335,10 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf template void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo const& info, int num_cuts_per_feature, bool is_ranking, float missing, - DeviceOrd device, size_t columns, size_t begin, 
size_t end, + size_t columns, size_t begin, size_t end, SketchContainer* sketch_container) { - dh::safe_cuda(cudaSetDevice(device.ordinal)); - info.weights_.SetDevice(device); + SetDevice(ctx->Ordinal()); + info.weights_.SetDevice(ctx->Device()); auto weights = info.weights_.ConstDeviceSpan(); auto batch_iter = dh::MakeTransformIterator( @@ -330,7 +349,7 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons dh::caching_device_vector column_sizes_scan; HostDeviceVector cuts_ptr; detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns, - num_cuts_per_feature, device, &cuts_ptr, &column_sizes_scan, + num_cuts_per_feature, ctx->Device(), &cuts_ptr, &column_sizes_scan, &sorted_entries); data::IsValidFunctor is_valid(missing); @@ -388,48 +407,59 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons sorted_entries.shrink_to_fit(); } -/* - * \brief Perform sketching on GPU. +/** + * @brief Perform sketching on GPU. * - * \param batch A batch from adapter. - * \param num_bins Bins per column. - * \param info Metainfo used for sketching. - * \param missing Floating point value that represents invalid value. - * \param sketch_container Container for output sketch. - * \param sketch_batch_num_elements Number of element per-sliding window, use it only for + * @param batch A batch from adapter. + * @param num_bins Bins per column. + * @param info Metainfo used for sketching. + * @param missing Floating point value that represents invalid value. + * @param sketch_container Container for output sketch. + * @param sketch_batch_num_elements Number of element per-sliding window, use it only for * testing. */ template -void AdapterDeviceSketch(Context const* ctx, Batch batch, int num_bins, MetaInfo const& info, +void AdapterDeviceSketch(Context const* ctx, Batch batch, bst_bin_t num_bins, MetaInfo const& info, float missing, SketchContainer* sketch_container, - size_t sketch_batch_num_elements = 0) { - size_t num_rows = batch.NumRows(); + bst_idx_t sketch_batch_num_elements = detail::UnknownSketchNumElements()) { + bst_idx_t num_rows = batch.NumRows(); size_t num_cols = batch.NumCols(); - size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); - auto device = sketch_container->DeviceIdx(); + bool weighted = !info.weights_.Empty(); - if (weighted) { + bst_idx_t const kRemaining = batch.Size(); + bst_idx_t begin = 0; + + auto shape = detail::SketchShape{num_rows, num_cols, std::numeric_limits::max()}; + + while (begin < kRemaining) { + // Use total number of samples to estimate the needed cuts first, this doesn't hurt + // accuracy as total number of samples is larger. + auto num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); + // Estimate the memory usage based on the current available memory. sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits::max(), - device.ordinal, num_cuts_per_feature, true); - for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { - size_t end = - std::min(batch.Size(), static_cast(begin + sketch_batch_num_elements)); + sketch_batch_num_elements, shape, ctx->Ordinal(), num_cuts_per_feature, weighted, + sketch_container->MemCostBytes()); + // Re-estimate the needed number of cuts based on the size of the sub-batch. 
+ // + // The estimation of `sketch_batch_num_elements` assumes dense input, so the + // approximation here is reasonably accurate. It doesn't hurt accuracy since the + // estimated n_samples must be greater or equal to the actual n_samples thanks to the + // dense assumption. + auto approx_n_samples = std::max(sketch_batch_num_elements / num_cols, bst_idx_t{1}); + num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, approx_n_samples); + bst_idx_t end = + std::min(batch.Size(), static_cast(begin + sketch_batch_num_elements)); + + if (weighted) { ProcessWeightedSlidingWindow(ctx, batch, info, num_cuts_per_feature, - HostSketchContainer::UseGroup(info), missing, device, num_cols, - begin, end, sketch_container); - } - } else { - sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits::max(), - device.ordinal, num_cuts_per_feature, false); - for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { - size_t end = - std::min(batch.Size(), static_cast(begin + sketch_batch_num_elements)); + HostSketchContainer::UseGroup(info), missing, num_cols, begin, + end, sketch_container); + } else { ProcessSlidingWindow(ctx, batch, info, num_cols, begin, end, missing, sketch_container, num_cuts_per_feature); } + begin += sketch_batch_num_elements; } } } // namespace xgboost::common diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 295206f0aa34..f2c7e44619c4 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -309,7 +309,7 @@ void MergeImpl(Context const *ctx, Span const &d_x, void SketchContainer::Push(Context const *ctx, Span entries, Span columns_ptr, common::Span cuts_ptr, size_t total_cuts, Span weights) { - common::SetDevice(device_.ordinal); + common::SetDevice(ctx->Ordinal()); Span out; dh::device_vector cuts; bool first_window = this->Current().empty(); @@ -354,7 +354,7 @@ void SketchContainer::Push(Context const *ctx, Span entries, SpanFixError(); } else { this->Current().resize(n_uniques); - this->columns_ptr_.SetDevice(device_); + this->columns_ptr_.SetDevice(ctx->Device()); this->columns_ptr_.Resize(cuts_ptr.size()); auto d_cuts_ptr = this->columns_ptr_.DeviceSpan(); @@ -369,7 +369,7 @@ size_t SketchContainer::ScanInput(Context const *ctx, Span entries, * pruning or merging. We preserve the first type and remove the second type. */ timer_.Start(__func__); - dh::safe_cuda(cudaSetDevice(device_.ordinal)); + SetDevice(ctx->Ordinal()); CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1); auto key_it = dh::MakeTransformIterator( @@ -408,7 +408,7 @@ size_t SketchContainer::ScanInput(Context const *ctx, Span entries, void SketchContainer::Prune(Context const* ctx, std::size_t to) { timer_.Start(__func__); - dh::safe_cuda(cudaSetDevice(device_.ordinal)); + SetDevice(ctx->Ordinal()); OffsetT to_total = 0; auto& h_columns_ptr = columns_ptr_b_.HostVector(); @@ -443,7 +443,12 @@ void SketchContainer::Prune(Context const* ctx, std::size_t to) { void SketchContainer::Merge(Context const *ctx, Span d_that_columns_ptr, Span that) { - common::SetDevice(device_.ordinal); + SetDevice(ctx->Ordinal()); + auto self = dh::ToSpan(this->Current()); + LOG(DEBUG) << "Merge: self:" << HumanMemUnit(self.size_bytes()) << ". " + << "That:" << HumanMemUnit(that.size_bytes()) << ". " + << "This capacity:" << HumanMemUnit(this->MemCapacityBytes()) << "." 
<< std::endl; + timer_.Start(__func__); if (this->Current().size() == 0) { CHECK_EQ(this->columns_ptr_.HostVector().back(), 0); @@ -478,7 +483,6 @@ void SketchContainer::Merge(Context const *ctx, Span d_that_colum } void SketchContainer::FixError() { - dh::safe_cuda(cudaSetDevice(device_.ordinal)); auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); auto in = dh::ToSpan(this->Current()); dh::LaunchN(in.size(), [=] __device__(size_t idx) { @@ -503,7 +507,7 @@ void SketchContainer::FixError() { } void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) { - dh::safe_cuda(cudaSetDevice(device_.ordinal)); + SetDevice(ctx->Ordinal()); auto world = collective::GetWorldSize(); if (world == 1 || is_column_split) { return; @@ -541,7 +545,7 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) { std::vector recv_lengths; HostDeviceVector recvbuf; rc = collective::AllgatherV( - ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), device_), + ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), ctx->Device()), &recv_lengths, &recvbuf); collective::SafeColl(rc); for (std::size_t i = 0; i < recv_lengths.size() - 1; ++i) { @@ -563,9 +567,8 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) { } // Merge them into a new sketch. - SketchContainer new_sketch(this->feature_types_, num_bins_, - this->num_columns_, global_sum_rows, - this->device_); + SketchContainer new_sketch(this->feature_types_, num_bins_, this->num_columns_, global_sum_rows, + ctx->Device()); for (size_t i = 0; i < allworkers.size(); ++i) { auto worker = allworkers[i]; auto worker_ptr = @@ -593,7 +596,7 @@ struct InvalidCatOp { void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool is_column_split) { timer_.Start(__func__); - dh::safe_cuda(cudaSetDevice(device_.ordinal)); + SetDevice(ctx->Ordinal()); p_cuts->min_vals_.Resize(num_columns_); // Sync between workers. 
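Several hunks in this file make the same change: the cached device_ member is dropped and the device ordinal is taken from the Context passed to each call. A minimal sketch of that pattern; ContextLike, SetDevice, and Worker are stand-ins for illustration, not the actual classes.

struct ContextLike {
  int ordinal{0};
  int Ordinal() const { return ordinal; }
};

// Assumed helper for illustration, e.g. a thin wrapper over cudaSetDevice.
void SetDevice(int /*ordinal*/) {}

class Worker {
 public:
  void Run(ContextLike const* ctx) {
    SetDevice(ctx->Ordinal());  // the device travels with the call, not the object
    // ... launch device work here ...
  }
};

int main() {
  ContextLike ctx;
  Worker{}.Run(&ctx);
  return 0;
}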
@@ -606,12 +609,12 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i // Set up inputs auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); - p_cuts->min_vals_.SetDevice(device_); + p_cuts->min_vals_.SetDevice(ctx->Device()); auto d_min_values = p_cuts->min_vals_.DeviceSpan(); auto const in_cut_values = dh::ToSpan(this->Current()); // Set up output ptr - p_cuts->cut_ptrs_.SetDevice(device_); + p_cuts->cut_ptrs_.SetDevice(ctx->Device()); auto& h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector(); h_out_columns_ptr.clear(); h_out_columns_ptr.push_back(0); @@ -689,7 +692,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan(); size_t total_bins = h_out_columns_ptr.back(); - p_cuts->cut_values_.SetDevice(device_); + p_cuts->cut_values_.SetDevice(ctx->Device()); p_cuts->cut_values_.Resize(total_bins); auto out_cut_values = p_cuts->cut_values_.DeviceSpan(); diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 239388b3b62c..4d849540af9f 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -8,6 +8,7 @@ #include "categorical.h" #include "cuda_context.cuh" // for CUDAContext +#include "cuda_rt_utils.h" // for SetDevice #include "device_helpers.cuh" #include "error_msg.h" // for InvalidMaxBin #include "quantile.h" @@ -15,9 +16,7 @@ #include "xgboost/data.h" #include "xgboost/span.h" -namespace xgboost { -namespace common { - +namespace xgboost::common { class HistogramCuts; using WQSketch = WQuantileSketch; using SketchEntry = WQSketch::Entry; @@ -46,7 +45,6 @@ class SketchContainer { bst_idx_t num_rows_; bst_feature_t num_columns_; int32_t num_bins_; - DeviceOrd device_; // Double buffer as neither prune nor merge can be performed inplace. dh::device_vector entries_a_; @@ -100,12 +98,12 @@ class SketchContainer { */ SketchContainer(HostDeviceVector const& feature_types, bst_bin_t max_bin, bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device) - : num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { + : num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin} { CHECK(device.IsCUDA()); // Initialize Sketches for this dmatrix - this->columns_ptr_.SetDevice(device_); + this->columns_ptr_.SetDevice(device); this->columns_ptr_.Resize(num_columns + 1, 0); - this->columns_ptr_b_.SetDevice(device_); + this->columns_ptr_b_.SetDevice(device); this->columns_ptr_b_.Resize(num_columns + 1, 0); this->feature_types_.Resize(feature_types.Size()); @@ -123,8 +121,25 @@ class SketchContainer { timer_.Init(__func__); } - /* \brief Return GPU ID for this container. */ - [[nodiscard]] DeviceOrd DeviceIdx() const { return device_; } + /** + * @brief Calculate the memory cost of the container. 
+ */ + [[nodiscard]] std::size_t MemCapacityBytes() const { + auto constexpr kE = sizeof(typename decltype(this->entries_a_)::value_type); + auto n_bytes = (this->entries_a_.capacity() + this->entries_b_.capacity()) * kE; + n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_b_.Size()) * sizeof(OffsetT); + n_bytes += this->feature_types_.Size() * sizeof(FeatureType); + + return n_bytes; + } + [[nodiscard]] std::size_t MemCostBytes() const { + auto constexpr kE = sizeof(typename decltype(this->entries_a_)::value_type); + auto n_bytes = (this->entries_a_.size() + this->entries_b_.size()) * kE; + n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_b_.Size()) * sizeof(OffsetT); + n_bytes += this->feature_types_.Size() * sizeof(FeatureType); + + return n_bytes; + } /* \brief Whether the predictor matrix contains categorical features. */ bool HasCategorical() const { return has_categorical_; } /* \brief Accumulate weights of duplicated entries in input. */ @@ -166,6 +181,7 @@ class SketchContainer { this->Current().shrink_to_fit(); this->Other().clear(); this->Other().shrink_to_fit(); + LOG(DEBUG) << "Quantile memory cost:" << this->MemCapacityBytes(); } /* \brief Merge quantiles from other GPU workers. */ @@ -190,13 +206,13 @@ class SketchContainer { template > size_t Unique(Context const* ctx, KeyComp key_comp = thrust::equal_to{}) { timer_.Start(__func__); - dh::safe_cuda(cudaSetDevice(device_.ordinal)); - this->columns_ptr_.SetDevice(device_); + SetDevice(ctx->Ordinal()); + this->columns_ptr_.SetDevice(ctx->Device()); Span d_column_scan = this->columns_ptr_.DeviceSpan(); CHECK_EQ(d_column_scan.size(), num_columns_ + 1); Span entries = dh::ToSpan(this->Current()); HostDeviceVector scan_out(d_column_scan.size()); - scan_out.SetDevice(device_); + scan_out.SetDevice(ctx->Device()); auto d_scan_out = scan_out.DeviceSpan(); d_column_scan = this->columns_ptr_.DeviceSpan(); @@ -212,7 +228,6 @@ class SketchContainer { return n_uniques; } }; -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // XGBOOST_COMMON_QUANTILE_CUH_ diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index f981a181b89f..508a0e0b1b91 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -65,7 +65,9 @@ TEST(HistUtil, SketchBatchNumElements) { auto per_elem = detail::BytesPerElement(false); auto avail_elem = avail / per_elem; size_t rows = avail_elem / kCols * 10; - auto batch = detail::SketchBatchNumElements(0, rows, kCols, rows * kCols, device, 256, false); + auto shape = detail::SketchShape{rows, kCols, rows * kCols}; + auto batch = detail::SketchBatchNumElements(detail::UnknownSketchNumElements(), shape, device, + 256, false, 0); ASSERT_EQ(batch, avail_elem); }
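The MemCostBytes/MemCapacityBytes pair added in quantile.cuh differ only in whether the two device entry buffers are counted by their size or by their reserved capacity (the column pointers and feature types are counted the same way in both). A small standalone illustration of that distinction, with std::vector standing in for the device vectors:

#include <cassert>
#include <cstddef>  // for size_t
#include <vector>

int main() {
  std::vector<float> buf;
  buf.reserve(1024);  // reserved storage grows
  buf.resize(10);     // but only a small part is in use
  std::size_t cost = buf.size() * sizeof(float);          // analogous to MemCostBytes()
  std::size_t capacity = buf.capacity() * sizeof(float);  // analogous to MemCapacityBytes()
  assert(cost <= capacity);
  return 0;
}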