Skip to content

Commit

Permalink
[GraphBolt][CUDA] GPUCache performance fix. (#7073)
Browse files — Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Feb 2, 2024
1 parent 8568386 commit d5b03bc
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions graphbolt/src/cuda/gpu_cache.cu
Original file line number Diff line number Diff line change
@@ -43,20 +43,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
auto missing_keys =
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
cuda::CopyScalar<size_t> missing_len;
auto stream = cuda::GetCurrentStream();
auto allocator = cuda::GetAllocator();
auto missing_len_device = allocator.AllocateStorage<size_t>(1);
cache_->Query(
reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
values.data_ptr<float>(),
reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
stream);
reinterpret_cast<key_t *>(missing_keys.data_ptr()),
missing_len_device.get(), cuda::GetCurrentStream());
values = values.view(torch::kByte)
.slice(1, 0, num_bytes_)
.view(dtype_)
.view(shape_);
// To safely read missing_len, we synchronize
stream.synchronize();
cuda::CopyScalar<size_t> missing_len(missing_len_device.get());
missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
return std::make_tuple(values, missing_index, missing_keys);
Expand Down

0 comments on commit d5b03bc

Please sign in to comment.