diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index ee83ee5c6d3b8..01646949cd62d 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -46,6 +46,15 @@ #include namespace { +constexpr uint64_t getDefaultMaxThreadsPerBlock() { +#ifndef USE_ROCM + return 128; +#else + // bigger default + return 512; +#endif +} + template __global__ void indexing_backward_kernel( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -933,11 +942,13 @@ void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, c selfAddDimSize, selfNumel, reduce_add, alpha_value); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); + uint64_t defaultMaxBlockThreads = getDefaultMaxThreadsPerBlock(); const dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8))); const dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128)); const dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8))); - const dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128)); + //On ROCm, std::min -> ::min did not work as expected on when outTotalSize>=2147483648 + dim3 largeIndexBlock( (sourceTotalSize < defaultMaxBlockThreads) ? sourceTotalSize : defaultMaxBlockThreads ); if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(source) && @@ -1106,11 +1117,13 @@ void index_reduce_func_cuda_impl( selfReduceDimSize, selfNumel, reduce_func, alpha_value); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); + uint64_t defaultMaxBlockThreads = getDefaultMaxThreadsPerBlock(); dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128)); dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8))); - dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128)); + //On ROCm, std::min -> ::min did not work as expected on when outTotalSize>=2147483648 + dim3 largeIndexBlock( (sourceTotalSize < defaultMaxBlockThreads) ? sourceTotalSize : defaultMaxBlockThreads ); if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(source) && @@ -1334,14 +1347,6 @@ tensorInfoLegacyIfScalar(cuda::detail::TensorInfo ti) { return ti; } -constexpr uint64_t getDefaultMaxThreadsPerBlock() { -#ifndef USE_ROCM - return 128; -#else - // bigger default - return 512; -#endif -} }