diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index c96c5955da8e8..a8af7604aac88 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -135,6 +135,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla/service:platform_util", "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", "@com_google_absl//absl/log", diff --git a/xla/service/gpu/launch_dimensions.cc b/xla/service/gpu/launch_dimensions.cc index 9bd0521923918..41f2fd4939d3b 100644 --- a/xla/service/gpu/launch_dimensions.cc +++ b/xla/service/gpu/launch_dimensions.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -38,19 +39,19 @@ LaunchDimensions CalculateLaunchDimensions( num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor}); const int kWarpSchedulers = 4; - int64_t threads_per_block = std::min( + int64_t threads_per_block_x = std::min( gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); - int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block*2); - const auto& max_grid_size = gpu_device_info.grid_size_limit(); - LOG(INFO) << max_grid_size.x << "\n"; - LOG(INFO) << max_grid_size.y << "\n"; - LOG(INFO) << max_grid_size.z << "\n"; - LOG(INFO) << "num blocks: " << num_blocks << "\n"; - LOG(INFO) << "threads_per_block: " << threads_per_block << "\n"; - - num_blocks = std::min(num_blocks, static_cast(max_grid_size.x)); + int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block_x); + CHECK(num_blocks < gpu_device_info.block_dim_limit().x); + int threads_per_block_y = 1; + if (xla::PlatformUtil::CanonicalPlatformName("gpu").value() == "rocm") { + while ((num_blocks * threads_per_block_x) > std::numeric_limits::max()) { + threads_per_block_x /= 2; + threads_per_block_y *= 2; + } + } return LaunchDimensions(se::BlockDim(num_blocks, 1, 1), - se::ThreadDim(threads_per_block, 1, 1)); + se::ThreadDim(threads_per_block_x, threads_per_block_y, 1)); } } // namespace gpu