Skip to content

Commit

Permalink
[EM] Improve memory estimation for quantile sketching. (#10843)
Browse files Browse the repository at this point in the history
I- Add basic estimation for RMM.
- Re-estimate after every sub-batch.
- Some debug logs for memory usage.
- Fix the locking mechanism in the memory allocator logger.
  • Loading branch information
trivialfis authored Sep 24, 2024
1 parent f3df0d0 commit bc69a3e
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 104 deletions.
60 changes: 47 additions & 13 deletions src/common/device_vector.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <cub/util_device.cuh> // for CurrentDevice
#include <map> // for map
#include <memory> // for unique_ptr
#include <mutex> // for defer_lock

#include "common.h" // for safe_cuda, HumanMemUnit
#include "xgboost/logging.h"
Expand All @@ -46,6 +47,12 @@ class MemoryLogger {
size_t num_deallocations{0};
std::map<void *, size_t> device_allocations;
void RegisterAllocation(void *ptr, size_t n) {
auto itr = device_allocations.find(ptr);
if (itr != device_allocations.cend()) {
LOG(WARNING) << "Attempting to allocate " << n << " bytes."
<< " that was already allocated\nptr:" << ptr << "\n"
<< dmlc::StackTrace();
}
device_allocations[ptr] = n;
currently_allocated_bytes += n;
peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes);
Expand All @@ -56,7 +63,7 @@ class MemoryLogger {
auto itr = device_allocations.find(ptr);
if (itr == device_allocations.end()) {
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device
<< " that was never allocated\n"
<< " that was never allocated\nptr:" << ptr << "\n"
<< dmlc::StackTrace();
} else {
num_deallocations++;
Expand All @@ -70,18 +77,34 @@ class MemoryLogger {
std::mutex mutex_;

public:
void RegisterAllocation(void *ptr, size_t n) {
/**
* @brief Register the allocation for logging.
*
* @param lock Set to false if the allocator has locking machanism.
*/
void RegisterAllocation(void *ptr, size_t n, bool lock) {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
return;
}
std::lock_guard<std::mutex> guard(mutex_);
std::unique_lock guard{mutex_, std::defer_lock};
if (lock) {
guard.lock();
}
stats_.RegisterAllocation(ptr, n);
}
void RegisterDeallocation(void *ptr, size_t n) {
/**
* @brief Register the deallocation for logging.
*
* @param lock Set to false if the allocator has locking machanism.
*/
void RegisterDeallocation(void *ptr, size_t n, bool lock) {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
return;
}
std::lock_guard<std::mutex> guard(mutex_);
std::unique_lock guard{mutex_, std::defer_lock};
if (lock) {
guard.lock();
}
stats_.RegisterDeallocation(ptr, n, cub::CurrentDevice());
}
size_t PeakMemory() const { return stats_.peak_allocated_bytes; }
Expand Down Expand Up @@ -140,11 +163,12 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
} catch (const std::exception &e) {
detail::ThrowOOMError(e.what(), n * sizeof(T));
}
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
// We can't place a lock here as template allocator is transient.
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T), true);
return ptr;
}
void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true);
SuperT::deallocate(ptr, n);
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
Expand Down Expand Up @@ -193,11 +217,12 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
detail::ThrowOOMError(e.what(), n * sizeof(T));
}
}
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
// We can't place a lock here as template allocator is transient.
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T), true);
return thrust_ptr;
}
void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true);
if (use_cub_allocator_) {
GetGlobalCachingAllocator().DeviceFree(ptr.get());
} else {
Expand Down Expand Up @@ -239,14 +264,15 @@ using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocato
*/
class LoggingResource : public rmm::mr::device_memory_resource {
rmm::mr::device_memory_resource *mr_{rmm::mr::get_current_device_resource()};
std::mutex lock_;

public:
LoggingResource() = default;
~LoggingResource() override = default;
LoggingResource(LoggingResource const &) = delete;
LoggingResource &operator=(LoggingResource const &) = delete;
LoggingResource(LoggingResource &&) noexcept = default;
LoggingResource &operator=(LoggingResource &&) noexcept = default;
LoggingResource(LoggingResource &&) noexcept = delete;
LoggingResource &operator=(LoggingResource &&) noexcept = delete;

[[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept { // NOLINT
return mr_;
Expand All @@ -256,9 +282,13 @@ class LoggingResource : public rmm::mr::device_memory_resource {
}

void *do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override { // NOLINT
std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
guard.lock();
}
try {
auto const ptr = mr_->allocate(bytes, stream);
GlobalMemoryLogger().RegisterAllocation(ptr, bytes);
GlobalMemoryLogger().RegisterAllocation(ptr, bytes, false);
return ptr;
} catch (rmm::bad_alloc const &e) {
detail::ThrowOOMError(e.what(), bytes);
Expand All @@ -268,8 +298,12 @@ class LoggingResource : public rmm::mr::device_memory_resource {

void do_deallocate(void *ptr, std::size_t bytes, // NOLINT
rmm::cuda_stream_view stream) override {
std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
guard.lock();
}
mr_->deallocate(ptr, bytes, stream);
GlobalMemoryLogger().RegisterDeallocation(ptr, bytes);
GlobalMemoryLogger().RegisterDeallocation(ptr, bytes, false);
}

[[nodiscard]] bool do_is_equal( // NOLINT
Expand Down
38 changes: 23 additions & 15 deletions src/common/hist_util.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2018~2023 by XGBoost contributors
* Copyright 2018~2024, XGBoost contributors
*/
#include <thrust/binary_search.h>
#include <thrust/copy.h>
Expand Down Expand Up @@ -32,13 +32,12 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
double eps = 1.0 / (WQSketch::kFactor * max_bins);
size_t dummy_nlevel;
size_t num_cuts;
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
num_rows, eps, &dummy_nlevel, &num_cuts);
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(num_rows, eps, &dummy_nlevel, &num_cuts);
return std::min(num_cuts, num_rows);
}

size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns,
size_t max_bins, size_t nnz) {
size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, size_t max_bins,
bst_idx_t nnz) {
auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows);
auto if_dense = num_columns * per_column;
auto result = std::min(nnz, if_dense);
Expand Down Expand Up @@ -83,23 +82,31 @@ size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz,
return peak;
}

size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_rows,
bst_feature_t columns, size_t nnz, int device, size_t num_cuts,
bool has_weight) {
bst_idx_t SketchBatchNumElements(bst_idx_t sketch_batch_num_elements, SketchShape shape, int device,
size_t num_cuts, bool has_weight, std::size_t container_bytes) {
auto constexpr kIntMax = static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max());
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
// device available memory is not accurate when rmm is used.
return std::min(nnz, kIntMax);
// Device available memory is not accurate when rmm is used.
double total_mem = dh::TotalMemory(device) - container_bytes;
double total_f32 = total_mem / sizeof(float);
double n_max_used_f32 = std::max(total_f32 / 16.0, 1.0); // a quarter
if (shape.nnz > shape.Size()) {
// Unknown nnz
shape.nnz = shape.Size();
}
return std::min(static_cast<bst_idx_t>(n_max_used_f32), shape.nnz);
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
(void)container_bytes; // We known the remaining size when RMM is not used.

if (sketch_batch_num_elements == 0) {
auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight);
if (sketch_batch_num_elements == detail::UnknownSketchNumElements()) {
auto required_memory =
RequiredMemory(shape.n_samples, shape.n_features, shape.nnz, num_cuts, has_weight);
// use up to 80% of available space
auto avail = dh::AvailableMemory(device) * 0.8;
if (required_memory > avail) {
sketch_batch_num_elements = avail / BytesPerElement(has_weight);
} else {
sketch_batch_num_elements = std::min(num_rows * static_cast<size_t>(columns), nnz);
sketch_batch_num_elements = std::min(shape.Size(), shape.nnz);
}
}

Expand Down Expand Up @@ -338,8 +345,9 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
// Configure batch size based on available memory
std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_);
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, info.num_row_, info.num_col_, info.num_nonzero_, ctx->Ordinal(),
num_cuts_per_feature, has_weight);
sketch_batch_num_elements,
detail::SketchShape{info.num_row_, info.num_col_, info.num_nonzero_}, ctx->Ordinal(),
num_cuts_per_feature, has_weight, 0);

CUDAContext const* cuctx = ctx->CUDACtx();

Expand Down
Loading

0 comments on commit bc69a3e

Please sign in to comment.