From d5b03bcb27a8467ff6a7738c895287b520eee36b Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Sat, 3 Feb 2024 01:49:54 +0300
Subject: [PATCH] [GraphBolt][CUDA] GPUCache performance fix. (#7073)

---
 graphbolt/src/cuda/gpu_cache.cu | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/graphbolt/src/cuda/gpu_cache.cu b/graphbolt/src/cuda/gpu_cache.cu
index 0a47bbbddc18..7c479fcc0c10 100644
--- a/graphbolt/src/cuda/gpu_cache.cu
+++ b/graphbolt/src/cuda/gpu_cache.cu
@@ -43,20 +43,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
       torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
   auto missing_keys =
       torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
-  cuda::CopyScalar<size_t> missing_len;
-  auto stream = cuda::GetCurrentStream();
+  auto allocator = cuda::GetAllocator();
+  auto missing_len_device = allocator.AllocateStorage<size_t>(1);
   cache_->Query(
       reinterpret_cast<key_t *>(keys.data_ptr()), keys.size(0),
       values.data_ptr<float>(),
       reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
-      reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
-      stream);
+      reinterpret_cast<key_t *>(missing_keys.data_ptr()),
+      missing_len_device.get(), cuda::GetCurrentStream());
   values = values.view(torch::kByte)
                .slice(1, 0, num_bytes_)
                .view(dtype_)
                .view(shape_);
-  // To safely read missing_len, we synchronize
-  stream.synchronize();
+  cuda::CopyScalar<size_t> missing_len(missing_len_device.get());
   missing_index = missing_index.slice(0, 0, static_cast<int64_t>(missing_len));
   missing_keys = missing_keys.slice(0, 0, static_cast<int64_t>(missing_len));
   return std::make_tuple(values, missing_index, missing_keys);
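
Note on the pattern: the patch has the cache kernel write its missing-count into device memory obtained from the caching allocator and only then wraps that pointer in `cuda::CopyScalar`, instead of handing the kernel a host-side scalar and calling `stream.synchronize()` right after the launch. Below is a minimal sketch of such a copy-scalar helper, assuming plain CUDA runtime calls and a pinned host staging buffer; it is illustrative only, not the GraphBolt `CopyScalar` implementation, and `PinnedScalar`, its constructor arguments, and the conversion operator are hypothetical names.

```cpp
#include <cuda_runtime.h>

#include <cstddef>
#include <stdexcept>

// Sketch: copy a device-resident scalar to pinned host memory asynchronously,
// and defer the wait until the value is actually read.
template <typename T>
class PinnedScalar {
 public:
  PinnedScalar(const T* device_ptr, cudaStream_t stream) : stream_(stream) {
    // Pinned allocation so the device-to-host copy can run asynchronously.
    if (cudaMallocHost(reinterpret_cast<void**>(&host_ptr_), sizeof(T)) !=
        cudaSuccess)
      throw std::runtime_error("cudaMallocHost failed");
    // Enqueue the copy on the given stream; this call does not block.
    cudaMemcpyAsync(
        host_ptr_, device_ptr, sizeof(T), cudaMemcpyDeviceToHost, stream_);
  }
  PinnedScalar(const PinnedScalar&) = delete;
  PinnedScalar& operator=(const PinnedScalar&) = delete;

  // Reading the value is the only blocking point, and it waits on this one
  // stream rather than the whole device.
  operator T() const {
    cudaStreamSynchronize(stream_);
    return *host_ptr_;
  }

  ~PinnedScalar() { cudaFreeHost(host_ptr_); }

 private:
  T* host_ptr_ = nullptr;
  cudaStream_t stream_;
};

// Usage sketch: after enqueueing work that writes *d_count on `stream`:
//   PinnedScalar<size_t> count(d_count, stream);
//   ... enqueue further, unrelated work here ...
//   size_t n = count;  // synchronizes `stream` only, and only at this point
```

The design point mirrored here is that the kernel writes its counter to ordinary device memory (fast), the copy back to the host is queued on the same stream, and synchronization is paid lazily where the count is consumed, rather than eagerly right after the query call.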