Skip to content

Commit

Permalink
[GraphBolt][CUDA] IndexSelectCSC kernel launch config change. (#7056)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Feb 2, 2024
1 parent 50eb101 commit 177dc13
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions graphbolt/src/cuda/index_select_csc_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
#include <numeric>

#include "./common.h"
#include "./max_uva_threads.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;
constexpr int BLOCK_SIZE = CUDA_MAX_NUM_THREADS;

// Given the in_degree array and a permutation, returns in_degree of the output
// and the permuted and modified in_degree of the input. The modified in_degree
Expand Down Expand Up @@ -130,7 +131,10 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
torch::Tensor output_indices =
torch::empty(output_size.value(), options.dtype(indices.scalar_type()));
const dim3 block(BLOCK_SIZE);
const dim3 grid((edge_count_aligned + BLOCK_SIZE - 1) / BLOCK_SIZE);
const dim3 grid(
(std::min(edge_count_aligned, cuda::max_uva_threads.value_or(1 << 20)) +
BLOCK_SIZE - 1) /
BLOCK_SIZE);

// Find the smallest integer type to store the coo_aligned_rows tensor.
const int num_bits = cuda::NumberOfBits(num_nodes);
Expand Down
2 changes: 1 addition & 1 deletion graphbolt/src/cuda/index_select_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
IndexSelectSingleKernel, num_blocks, num_threads, 0, input_ptr,
input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
} else {
constexpr int BLOCK_SIZE = 512;
constexpr int BLOCK_SIZE = CUDA_MAX_NUM_THREADS;
dim3 block(BLOCK_SIZE, 1);
while (static_cast<int64_t>(block.x) >= 2 * aligned_feature_size) {
block.x >>= 1;
Expand Down

0 comments on commit 177dc13

Please sign in to comment.