From 6d7265ba17be640da97c87d256919dadb4fe4d71 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Mon, 26 Feb 2024 23:08:18 +0000 Subject: [PATCH 01/11] use potrfDSCZ --- xla/service/gpu/cusolver_context.cc | 104 ++++++++++++++++++++++++---- 1 file changed, 91 insertions(+), 13 deletions(-) diff --git a/xla/service/gpu/cusolver_context.cc b/xla/service/gpu/cusolver_context.cc index 4343fec8b1987..f8db22495dd5f 100644 --- a/xla/service/gpu/cusolver_context.cc +++ b/xla/service/gpu/cusolver_context.cc @@ -53,12 +53,12 @@ struct GpuComplexT { // For ROCm, use hipsolver if the ROCm version >= 4.5 and // rocblas/rocsolver if the ROCm version < 4.5. -#if !TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA #define GPU_SOLVER_CONTEXT_PREFIX cusolverDn #define GPU_SOLVER_PREFIX cusolverDn -using gpuStream_t = cudaStream_t; +using gpuDataType_t = cudaDataType_t; template <> struct GpuComplexT> { @@ -78,9 +78,9 @@ struct GpuComplexT*> { typedef cuDoubleComplex* type; }; -#else +#elif TENSORFLOW_USE_ROCM -using gpuStream_t = hipStream_t; +using gpuDataType_t = hipDataType; #if TF_ROCM_VERSION >= 40500 #define GPU_SOLVER_CONTEXT_PREFIX se::wrap::hipsolver @@ -126,14 +126,14 @@ struct GpuComplexT*> { }; #endif // TF_ROCM_VERSION >= 40500 -#endif // !TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_USE_ROCM template inline typename GpuComplexT::type* ToDevicePointer(se::DeviceMemory p) { return static_cast::type*>(p.opaque()); } -#if !TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA cublasFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) { switch (uplo) { case se::blas::UpperLower::kUpper: @@ -176,7 +176,8 @@ absl::Status ConvertStatus(cusolverStatus_t status) { return Unknown("Unknown cuSolver error"); } } -#else +#elif TENSORFLOW_USE_ROCM + #if TF_ROCM_VERSION >= 40500 hipsolverFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) { switch (uplo) { @@ -185,7 +186,7 @@ hipsolverFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) { case se::blas::UpperLower::kLower: return HIPSOLVER_FILL_MODE_LOWER; default: - LOG(FATAL) << "Invalid value of blas::UpperLower."; + LOG(FATAL) << "Invalid value of blas::UpperLower"; } } @@ -219,7 +220,7 @@ absl::Status ConvertStatus(hipsolverStatus_t status) { return Unknown("Unknown hipsolver error"); } } -#else +#else // TF_ROCM_VERSION < 40500 rocblas_fill GpuBlasUpperLower(se::blas::UpperLower uplo) { switch (uplo) { case se::blas::UpperLower::kUpper: @@ -341,12 +342,14 @@ void GpuSolverContext::Deleter::operator()(gpusolverHandle_t handle) { absl::StatusOr GpuSolverContext::PotrfBufferSize( PrimitiveType type, se::blas::UpperLower uplo, int n, int lda, int batch_size) { -#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER + int size = -1; + auto gpu_uplo = GpuBlasUpperLower(uplo); +#if GOOGLE_CUDA size_t d_lwork = 0; /* size of workspace */ size_t h_lwork = 0; /* size of workspace */ - cudaDataType_t cuda_data_type; + gpuDataType_t cuda_data_type; switch (type) { case F32: { cuda_data_type = CUDA_R_32F; @@ -369,11 +372,44 @@ absl::StatusOr GpuSolverContext::PotrfBufferSize( PrimitiveType_Name(type)); } TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverDnXpotrf_bufferSize( - handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, cuda_data_type, + handle_.get(), nullptr, gpu_uplo, n, cuda_data_type, nullptr, lda, cuda_data_type, &d_lwork, &h_lwork))); size = static_cast(d_lwork); - // CUDA's potrfBatched needs space for the `as` array, which contains +#elif TENSORFLOW_USE_HIPSOLVER + switch (type) { + case F32: { + TF_RETURN_IF_ERROR(ConvertStatus( + GpuSolverSpotrf_bufferSize(handle_.get(), gpu_uplo, n, + 
/*A=*/nullptr, lda, &size))); + break; + } + case F64: { + TF_RETURN_IF_ERROR(ConvertStatus( + GpuSolverDpotrf_bufferSize(handle_.get(), gpu_uplo, n, + /*A=*/nullptr, lda, &size))); + break; + } + case C64: { + TF_RETURN_IF_ERROR(ConvertStatus( + GpuSolverCpotrf_bufferSize(handle_.get(), gpu_uplo, n, + /*A=*/nullptr, lda, &size))); + break; + } + case C128: { + TF_RETURN_IF_ERROR(ConvertStatus( + GpuSolverZpotrf_bufferSize(handle_.get(), gpu_uplo, n, + /*A=*/nullptr, lda, &size))); + break; + } + default: + return InvalidArgument("Invalid type for cholesky decomposition: %s", + PrimitiveType_Name(type)); + } +#endif // TENSORFLOW_USE_HIPSOLVER + +#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER + // CUDA/HIP's potrfBatched needs space for the `as` array, which contains // batch_size pointers. Divide by sizeof(type) because this function returns // not bytes but a number of elements of `type`. int64_t potrf_batched_scratch = CeilOfRatio( @@ -434,6 +470,7 @@ absl::Status GpuSolverContext::PotrfBatched( ToDevicePointer(lapack_info), batch_size)); } +#if GOOGLE_CUDA absl::Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory a, int lda, se::DeviceMemory lapack_info, @@ -477,6 +514,47 @@ absl::Status GpuSolverContext::Potrf( workspace.ElementCount(), nullptr, 0, ToDevicePointer(lapack_info))); return status; } +#elif TENSORFLOW_USE_HIPSOLVER +absl::Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace) { + return ConvertStatus(GpuSolverDpotrf( + handle_.get(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda, + nullptr, 0, + ToDevicePointer(lapack_info))); +} + +absl::Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace) { + return ConvertStatus(GpuSolverSpotrf( + handle_.get(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda, + nullptr, 0, + ToDevicePointer(lapack_info))); +} + +absl::Status GpuSolverContext::Potrf( + se::blas::UpperLower uplo, int n, se::DeviceMemory> a, + int lda, se::DeviceMemory lapack_info, + se::DeviceMemory> workspace) { + return ConvertStatus(GpuSolverCpotrf( + handle_.get(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda, + nullptr, 0, + ToDevicePointer(lapack_info))); +} + +absl::Status GpuSolverContext::Potrf( + se::blas::UpperLower uplo, int n, se::DeviceMemory> a, + int lda, se::DeviceMemory lapack_info, + se::DeviceMemory> workspace) { + return ConvertStatus(GpuSolverZpotrf( + handle_.get(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda, + nullptr, 0, + ToDevicePointer(lapack_info))); +} +#endif // TENSORFLOW_USE_HIPSOLVER } // namespace gpu } // namespace xla From d965fd3a817838c4b7440868afb6884ea011b078 Mon Sep 17 00:00:00 2001 From: Pavel Emeliyanenko Date: Tue, 27 Feb 2024 19:50:01 +0000 Subject: [PATCH 02/11] fixing use_hgemm_alt_impl_ flag --- xla/stream_executor/rocm/rocm_blas.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/stream_executor/rocm/rocm_blas.cc b/xla/stream_executor/rocm/rocm_blas.cc index d4c9b774682c6..c35117dbf3c9f 100644 --- a/xla/stream_executor/rocm/rocm_blas.cc +++ b/xla/stream_executor/rocm/rocm_blas.cc @@ -126,7 +126,7 @@ bool ROCMBlas::Init() { if (result == hipSuccess) { auto cap = RocmComputeCapability(props.gcnArchName); has_mfma_ = cap.has_mfma_instr_support(); - use_hgemm_alt_impl_ = (cap.gfx_version() == "90a"); + use_hgemm_alt_impl_ = (cap.gfx_version() == 
"gfx90a"); } return true; From e2c2c85a1f7d02bd5dd669d0b284a6a0a4fffa18 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Mon, 26 Feb 2024 23:08:18 +0000 Subject: [PATCH 03/11] ncclCommSplit is from rocm6.0 --- xla/service/gpu/nccl_api.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/xla/service/gpu/nccl_api.cc b/xla/service/gpu/nccl_api.cc index e09cac7c87919..10b2d0043cc84 100644 --- a/xla/service/gpu/nccl_api.cc +++ b/xla/service/gpu/nccl_api.cc @@ -371,6 +371,8 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id, VLOG(1) << "Initialize NCCL communicator for " << ranks.size() << " devices; hash(id)=" << absl::HashOf(clique_id); +#if !defined(TENSORFLOW_USE_ROCM) || (defined(TENSORFLOW_USE_ROCM) && + TF_ROCM_VERSION > 50700) ncclConfig_t comm_config = NCCL_CONFIG_INITIALIZER; comm_config.splitShare = config.split_share; if (config.max_nchannels > 0) { @@ -399,6 +401,11 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id, TF_RETURN_IF_ERROR(GroupEnd()); return comms; +#else + return absl::UnimplementedError( + absl::StrFormat("%s:%d: NCCL operation ncclCommInitRankConfig not implemented", + __FILE__, __LINE__)); +#endif } absl::StatusOr> DefaultNcclApi::CommSplit( From 5cc1ead35b6ac554abb601006eceff5e84b8f068 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Mon, 26 Feb 2024 23:13:11 +0000 Subject: [PATCH 04/11] fixed dd7604b8faab933099548452677836b72efdcebf on absl::Status --- xla/stream_executor/rocm/rocm_executor.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index c55a8c9eb7a28..879f7daff02ea 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -639,7 +639,7 @@ absl::Status GpuExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -bool GpuExecutor::Memcpy(Stream* stream, void* host_dst, +absl::Status GpuExecutor::Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { bool ok = GpuDriver::AsynchronousMemcpyD2H(context_, host_dst, AsROCmDevicePtr(gpu_src), size, @@ -652,7 +652,7 @@ bool GpuExecutor::Memcpy(Stream* stream, void* host_dst, return absl::OkStatus(); } -bool GpuExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, +absl::Status GpuExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) { bool ok = GpuDriver::AsynchronousMemcpyH2D(context_, AsROCmDevicePtr(gpu_dst), host_src, size, From f863d4d19faa3cfa7c46210fb0a1b26602aab40d Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Mon, 26 Feb 2024 22:42:09 +0000 Subject: [PATCH 05/11] fixed c95205516554a53d0c1c20399c94b9ad456abb4f --- xla/stream_executor/rocm/rocm_executor.cc | 7 +++---- xla/stream_executor/rocm/rocm_kernel.cc | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 879f7daff02ea..3a65cfce6d32e 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -315,7 +315,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, kernel->set_metadata(kernel_metadata); } kernel->set_name(*kernel_name); - kernel->set_kernel_args_packing(spec.kernel_args_packing()); + kernel->set_args_packing(spec.kernel_args_packing()); return absl::OkStatus(); } @@ -354,8 +354,7 @@ absl::Status GpuExecutor::Launch(Stream* stream, const 
ThreadDim& thread_dims,
     }
   }
 
-  if (rocm_kernel->GetPreferredCacheConfig() !=
-      KernelCacheConfig::kNoPreference) {
+  if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) {
     TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig(
         hipfunc, rocm_kernel->GetGpuCacheConfig()));
   }
@@ -377,7 +376,7 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
   if (packed_args) return launch(*packed_args);
 
   if (auto* device_mem = DynCast<KernelArgsDeviceMemoryArray>(&args)) {
-    auto& pack = kernel.kernel_args_packing();
+    auto& pack = kernel.args_packing();
     if (!pack) {
       return absl::InternalError(
           "Kernel is missing a custom arguments packing function for device "
diff --git a/xla/stream_executor/rocm/rocm_kernel.cc b/xla/stream_executor/rocm/rocm_kernel.cc
index 2b55c5ff304e3..74f2def14e22b 100644
--- a/xla/stream_executor/rocm/rocm_kernel.cc
+++ b/xla/stream_executor/rocm/rocm_kernel.cc
@@ -19,7 +19,7 @@ namespace stream_executor {
 namespace gpu {
 
 hipFuncCache_t GpuKernel::GetGpuCacheConfig() const {
-  switch (preferred_cache_config_) {
+  switch (cache_config()) {
     case KernelCacheConfig::kNoPreference:
       return hipFuncCachePreferNone;
     case KernelCacheConfig::kPreferShared:
@@ -30,7 +30,7 @@ hipFuncCache_t GpuKernel::GetGpuCacheConfig() const {
       return hipFuncCachePreferEqual;
     default:
       LOG(FATAL) << "Unknown KernelCacheConfig"
-                 << static_cast<int32>(preferred_cache_config_);
+                 << static_cast<int32>(cache_config());
   }
 }
 
From 9ed11935c7acc4cf7026cdd0548d6dd89d10c6a5 Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Mon, 26 Feb 2024 23:00:38 +0000
Subject: [PATCH 06/11] add GraphNodeGetDependencies in rocm

---
 xla/stream_executor/rocm/rocm_driver.cc | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc
index 2d075f06cd189..0fa0b6c5a1b48 100644
--- a/xla/stream_executor/rocm/rocm_driver.cc
+++ b/xla/stream_executor/rocm/rocm_driver.cc
@@ -643,6 +643,26 @@ static std::string_view StreamCaptureModeToString(
   return absl::OkStatus();
 }
 
+absl::StatusOr<std::vector<GpuGraphNodeHandle>>
+GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) {
+  VLOG(2) << "Get HIP graph node " << node << " dependencies";
+
+  std::vector<hipGraphNode_t> dependencies;
+
+  size_t num_dependencies = 0;
+  RETURN_IF_ROCM_ERROR(
+      hipGraphNodeGetDependencies(node, nullptr, &num_dependencies),
+      "Failed to get HIP graph node dependencies size");
+
+  dependencies.resize(num_dependencies, nullptr);
+  RETURN_IF_ROCM_ERROR(
+      hipGraphNodeGetDependencies(node, dependencies.data(), &num_dependencies),
+      "Failed to get HIP graph node dependencies");
+
+  return dependencies;
+}
+
+
 /* static */ absl::Status GpuDriver::DestroyGraphExec(hipGraphExec_t exec) {
   VLOG(2) << "Destroying HIP executable graph" << exec;
   RETURN_IF_ROCM_ERROR(wrap::hipGraphExecDestroy(exec),

From f53679fdb985e4599092ed36c164c603860ae7e0 Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Mon, 26 Feb 2024 23:34:01 +0000
Subject: [PATCH 07/11] no break

---
 xla/service/gpu/nccl_api.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/xla/service/gpu/nccl_api.cc b/xla/service/gpu/nccl_api.cc
index 10b2d0043cc84..d218af89430f6 100644
--- a/xla/service/gpu/nccl_api.cc
+++ b/xla/service/gpu/nccl_api.cc
@@ -371,8 +371,7 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id,
   VLOG(1) << "Initialize NCCL communicator for " << ranks.size()
           << " devices; hash(id)=" << absl::HashOf(clique_id);
 
-#if
!defined(TENSORFLOW_USE_ROCM) || (defined(TENSORFLOW_USE_ROCM) && TF_ROCM_VERSION > 50700) ncclConfig_t comm_config = NCCL_CONFIG_INITIALIZER; comm_config.splitShare = config.split_share; if (config.max_nchannels > 0) { From 348c3b8a8d790ecc8b8c2fd0f8fbd31884fa7a3c Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Fri, 23 Feb 2024 16:05:17 +0000 Subject: [PATCH 08/11] [ROCm] Use addresspace(5) for allocas on ROCM also --- xla/service/llvm_ir/llvm_util.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index 3a9ab3bb6f9c3..905105cc7c747 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -422,7 +422,8 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, function->getEntryBlock().getFirstInsertionPt()); llvm::Module* module = b->GetInsertBlock()->getModule(); // Explicitly set local addrspace for SPIR backend. - int addrspace = llvm::Triple(module->getTargetTriple()).isSPIR() ? 5 : 0; + llvm::Triple target(module->getTargetTriple()); + int addrspace = target.isSPIR() || target.isAMDGPU() ? 5 : 0; llvm::AllocaInst* alloca = b->CreateAlloca(type, addrspace, element_count, AsStringRef(name)); if (alignment != 0) { From b30236d3c7358efbdb277fc3ec1f45b841544461 Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Fri, 23 Feb 2024 07:13:48 +0000 Subject: [PATCH 09/11] [ROCm] Don't use CUDA PTX for ROCM in ComputationIdCmd --- xla/service/gpu/runtime/command_buffer_cmd.cc | 6 ++++++ xla/service/gpu/runtime/command_buffer_cmd.h | 2 ++ 2 files changed, 8 insertions(+) diff --git a/xla/service/gpu/runtime/command_buffer_cmd.cc b/xla/service/gpu/runtime/command_buffer_cmd.cc index ab59ad00d2ea3..1c1c81e378e37 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -496,6 +496,7 @@ CommandBufferCmd::BufferUsageVector ComputationIdCmd::buffers() { absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, StateManager& state) { +#if defined(GOOGLE_CUDA) { absl::MutexLock lock(&mutex_); if (memset_kernels_.contains(params.executor)) return absl::OkStatus(); @@ -508,6 +509,7 @@ absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, absl::MutexLock lock(&mutex_); memset_kernels_.emplace(params.executor, std::move(kernel)); +#endif // GOOGLE_CUDA return absl::OkStatus(); } @@ -534,6 +536,7 @@ absl::Status ComputationIdCmd::Record( << "; execution_scope_id=" << execution_scope_id.value(); VLOG(5) << " Id: " << dest_ << " (" << dst.opaque() << ")"; +#if defined(GOOGLE_CUDA) se::Kernel* memset_kernel = [&] { absl::MutexLock lock(&mutex_); return memset_kernels_[execute_params.stream->parent()].get(); @@ -547,6 +550,9 @@ absl::Status ComputationIdCmd::Record( auto args = se::PackKernelArgs(/*shmem_bytes=*/0, int64_t{1}, value, dst); return command_buffer->Launch(execution_scope_id, se::ThreadDim(1), se::BlockDim(1), *memset_kernel, *args); +#else + return command_buffer->Memset(execution_scope_id, &dst, value, /*num_elements=*/1); +#endif // GOOGLE_CUDA } //===----------------------------------------------------------------------===// diff --git a/xla/service/gpu/runtime/command_buffer_cmd.h b/xla/service/gpu/runtime/command_buffer_cmd.h index 559f6a93f082c..b71c95c808a39 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/xla/service/gpu/runtime/command_buffer_cmd.h @@ -410,6 +410,7 @@ class ComputationIdCmd : public CommandBufferCmd { 
BufferAllocation::Slice dest_; Kind kind_; +#if defined(GOOGLE_CUDA) // Command sequence can be recorded concurrently for multiple command buffers // on different stream executors and we need to synchronize mutable state. absl::Mutex mutex_; @@ -421,6 +422,7 @@ class ComputationIdCmd : public CommandBufferCmd { // memset. This should be removed when bug is fixed in CUDA. absl::flat_hash_map> memset_kernels_ ABSL_GUARDED_BY(mutex_); +#endif // GOOGLE_CUDA }; //===----------------------------------------------------------------------===// From 426e59e131feef1b3d6b2736f455803e1e58a8c3 Mon Sep 17 00:00:00 2001 From: Pavel Emeliyanenko Date: Fri, 19 Jan 2024 14:47:57 +0000 Subject: [PATCH 10/11] added ConvBfloat16Support HLO pass --- xla/service/gpu/amdgpu_compiler.cc | 37 ++++++++++++++++++++++++++ xla/stream_executor/rocm/rocm_driver.h | 2 +- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc index 0a985871914cd..4eff13885d8b1 100644 --- a/xla/service/gpu/amdgpu_compiler.cc +++ b/xla/service/gpu/amdgpu_compiler.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/service/algebraic_simplifier.h" #include "xla/service/call_inliner.h" #include "xla/service/dot_dimension_merger.h" +#include "xla/service/float_normalization.h" #include "xla/service/gpu/conv_algorithm_picker.h" #include "xla/service/gpu/cublas_pad_for_gemms.h" #include "xla/service/gpu/cublas_padding_requirements.h" @@ -53,6 +54,36 @@ limitations under the License. namespace xla { namespace gpu { +namespace { + +struct ConvBfloat16Support : public FloatSupport { + + explicit ConvBfloat16Support( + const se::RocmComputeCapability& rocm) + : FloatSupport(BF16), + // TODO: MIOpen does not support bf16 convolutions yet + is_conv_bf16_supported_(rocm.has_bf16_dtype_support()) {} + + bool SupportsLowPrecisionOperand(const HloInstruction& hlo, + int64_t operand_index) const override { + return (hlo.opcode() != HloOpcode::kConvolution) || is_conv_bf16_supported_; + } + + bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override { + return (hlo.opcode() != HloOpcode::kConvolution) || is_conv_bf16_supported_; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + // Skip all HLOs other than convolutions. + return (hlo.opcode() != HloOpcode::kConvolution); + } + + private: + bool is_conv_bf16_supported_; +}; + +} // namespace + absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, @@ -63,6 +94,12 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddInvariantCheckerDebug( /*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + + // Convert upsupported bf16 convolutions to f32. + ConvBfloat16Support conv_bf16_support( + std::get(gpu_version)); + pipeline.AddPass(&conv_bf16_support); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/xla/stream_executor/rocm/rocm_driver.h b/xla/stream_executor/rocm/rocm_driver.h index f94d9bf8f0c52..4cb76dd8151bf 100644 --- a/xla/stream_executor/rocm/rocm_driver.h +++ b/xla/stream_executor/rocm/rocm_driver.h @@ -29,7 +29,7 @@ namespace stream_executor { namespace gpu { // Formats hipError_t to output prettified values into a log stream. // Error summaries taken from: -string ToString(hipError_t result); +std::string ToString(hipError_t result); // GpuContext wraps the device_ordinal and hipCtx_t handle. 
class GpuContext { From 01f31a5b2b6a9745372bc2e85b069080db5d2224 Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Wed, 6 Mar 2024 15:49:15 +0000 Subject: [PATCH 11/11] Add CudnnPadForConvolutions and CudnnVecotrizeConvolutions HLO pass --- xla/service/gpu/BUILD | 3 +++ xla/service/gpu/amdgpu_compiler.cc | 17 +++++++++++++++-- xla/service/gpu/cudnn_pad_for_convolutions.cc | 11 ++++++++--- xla/service/gpu/cudnn_pad_for_convolutions.h | 5 ++++- xla/service/gpu/cudnn_support_utils.cc | 10 ++++++---- xla/service/gpu/cudnn_support_utils.h | 2 +- xla/service/gpu/cudnn_vectorize_convolutions.cc | 9 ++++++--- xla/service/gpu/cudnn_vectorize_convolutions.h | 7 ++++++- xla/stream_executor/device_description.h | 8 ++++++++ 9 files changed, 57 insertions(+), 15 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 104937f2b8f6f..75486ee8ea409 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -3907,6 +3907,9 @@ cc_library( ":conv_algorithm_picker", ":cublas_pad_for_gemms", ":cublas_padding_requirements", + ":cudnn_pad_for_convolutions", + ":cudnn_simplify_padding", + ":cudnn_vectorize_convolutions", ":cusolver_rewriter", ":gemm_algorithm_picker", ":gemm_rewriter", diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc index 4eff13885d8b1..e324adac1cab5 100644 --- a/xla/service/gpu/amdgpu_compiler.cc +++ b/xla/service/gpu/amdgpu_compiler.cc @@ -28,6 +28,9 @@ limitations under the License. #include "xla/service/gpu/conv_algorithm_picker.h" #include "xla/service/gpu/cublas_pad_for_gemms.h" #include "xla/service/gpu/cublas_padding_requirements.h" +#include "xla/service/gpu/cudnn_pad_for_convolutions.h" +#include "xla/service/gpu/cudnn_simplify_padding.h" +#include "xla/service/gpu/cudnn_vectorize_convolutions.h" #include "xla/service/gpu/cusolver_rewriter.h" #include "xla/service/gpu/gemm_algorithm_picker.h" #include "xla/service/gpu/gemm_rewriter.h" @@ -88,6 +91,8 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, se::DeviceMemoryAllocator* device_allocator) { + auto rocm_compute_capability = + std::get(gpu_version); // Convert convolutions into CustomCalls to MIOpen, then canonicalize them // (PadInsertion). HloPassPipeline pipeline("conv_canonicalization"); @@ -96,13 +101,14 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( /*allow_mixed_precision=*/false); // Convert upsupported bf16 convolutions to f32. - ConvBfloat16Support conv_bf16_support( - std::get(gpu_version)); + ConvBfloat16Support conv_bf16_support(rocm_compute_capability); pipeline.AddPass(&conv_bf16_support); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(rocm_compute_capability); + pipeline.AddPass(rocm_compute_capability); // The conv padding/vectorization passes which we need to get rid of. They // also leave behind unnecessary tuple/get-tuple-element pairs that @@ -119,6 +125,13 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( options.set_enable_unconditional_reduce_of_concat_replacement(false); pipeline.AddPass>(options); + // CudnnSimplifyPadding gets rid of some padding introduced by + // CudnnPadForConvolutions and used by CudnnVectorizeConvolutions. The + // pattern-matches in this pass need to be run after inlining and simplifying + // tuples from CudnnVectorizeConvolutions. We also need to run algsimp to + // e.g. clean up unnecessary nop `convert`s. 
+ pipeline.AddPass(); + pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); diff --git a/xla/service/gpu/cudnn_pad_for_convolutions.cc b/xla/service/gpu/cudnn_pad_for_convolutions.cc index e104eea0530e6..74a638bd7d676 100644 --- a/xla/service/gpu/cudnn_pad_for_convolutions.cc +++ b/xla/service/gpu/cudnn_pad_for_convolutions.cc @@ -315,7 +315,7 @@ static absl::StatusOr TryResolvePaddedShapesForTensorCore( // Adds padding to cudnn integer convolutions to make input and output feature // maps multiples of pad_to (usually 4 or 32). absl::StatusOr TryResolvePaddedShapesForIntegerConvolution( - int pad_to, const se::CudaComputeCapability& compute_capability, + int pad_to, const se::GpuComputeCapability& compute_capability, HloCustomCallInstruction* conv, std::vector* new_input_shapes_ptr, Shape* new_result_shape_ptr) { TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); @@ -490,13 +490,16 @@ absl::StatusOr CudnnPadForConvolutions::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; + auto *ccc = std::get_if(&compute_capability_); for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { // On Turing and later (sm75+), pad to multiples of 32 bytes if possible, // because that lets us use the fast int8x32 data type. bool local_changed = false; - if (compute_capability_.IsAtLeast(7, 5)) { + bool isSM75_and_later = false; + if (ccc) isSM75_and_later = ccc->IsAtLeast(7, 5); + if (isSM75_and_later || se::isROCm(compute_capability_)) { TF_ASSIGN_OR_RETURN( local_changed, ResolveAndPad(conv, absl::bind_front( @@ -512,7 +515,9 @@ absl::StatusOr CudnnPadForConvolutions::Run( } changed |= local_changed; } - if (compute_capability_.IsAtLeast(se::CudaComputeCapability::VOLTA)) { + bool isVOLTA = false; + if (ccc) isVOLTA = ccc->IsAtLeast(se::CudaComputeCapability::VOLTA); + if (isVOLTA || se::isROCm(compute_capability_)) { for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { TF_ASSIGN_OR_RETURN( bool local_changed, diff --git a/xla/service/gpu/cudnn_pad_for_convolutions.h b/xla/service/gpu/cudnn_pad_for_convolutions.h index e37f45f3e48ad..571f4afdb1698 100644 --- a/xla/service/gpu/cudnn_pad_for_convolutions.h +++ b/xla/service/gpu/cudnn_pad_for_convolutions.h @@ -34,6 +34,9 @@ class CudnnPadForConvolutions : public HloModulePass { explicit CudnnPadForConvolutions(se::CudaComputeCapability compute_capability) : compute_capability_(compute_capability) {} + explicit CudnnPadForConvolutions(se::RocmComputeCapability compute_capability) + : compute_capability_(compute_capability) {} + absl::string_view name() const override { return "cudnn_pad_for_convolutions"; } @@ -44,7 +47,7 @@ class CudnnPadForConvolutions : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: - const se::CudaComputeCapability compute_capability_; + const se::GpuComputeCapability compute_capability_; }; } // namespace gpu diff --git a/xla/service/gpu/cudnn_support_utils.cc b/xla/service/gpu/cudnn_support_utils.cc index 7f9cf7074a58a..3294e64c65f38 100644 --- a/xla/service/gpu/cudnn_support_utils.cc +++ b/xla/service/gpu/cudnn_support_utils.cc @@ -33,7 +33,7 @@ namespace xla { namespace gpu { absl::StatusOr CudnnSupportsOptimizedIntegerConvolution( - const se::CudaComputeCapability& compute_capability, + const se::GpuComputeCapability& compute_capability, HloCustomCallInstruction& conv, int vector_size) { TF_ASSIGN_OR_RETURN(auto kind, 
GetCudnnConvKind(&conv));
   const Shape& input_shape = conv.operand(0)->shape();
@@ -50,9 +50,11 @@ absl::StatusOr<bool> CudnnSupportsOptimizedIntegerConvolution(
 
   // Require cc6.1+ for any vectorized integer convolutions
   // Require cc7.5+ for any IMMA convolutions
-  if ((vector_size == 32 && !compute_capability.IsAtLeast(7, 5)) ||
-      !compute_capability.IsAtLeast(6, 1)) {
-    VLOG(3) << "Compute capability " << compute_capability.ToString()
+  bool isCUDA = std::holds_alternative<se::CudaComputeCapability>(compute_capability);
+  auto cuda_compute_capability = std::get<se::CudaComputeCapability>(compute_capability);
+  if ((vector_size == 32 && !cuda_compute_capability.IsAtLeast(7, 5)) ||
+      !cuda_compute_capability.IsAtLeast(6, 1)) {
+    VLOG(3) << "Compute capability " << cuda_compute_capability.ToString()
             << " is not sufficient for int8x" << vector_size
             << " vectorization.";
     return false;
diff --git a/xla/service/gpu/cudnn_support_utils.h b/xla/service/gpu/cudnn_support_utils.h
index f0132f13cd26b..03cd22219b620 100644
--- a/xla/service/gpu/cudnn_support_utils.h
+++ b/xla/service/gpu/cudnn_support_utils.h
@@ -32,7 +32,7 @@ namespace gpu {
 // This function does not guarantee that a convolution will be padded and/or
 // vectorized. It only checks that it is a valid candidate for such optimization.
 absl::StatusOr<bool> CudnnSupportsOptimizedIntegerConvolution(
-    const se::CudaComputeCapability& compute_capability,
+    const se::GpuComputeCapability& compute_capability,
     HloCustomCallInstruction& conv, int vector_size);
 
 // Represents configuration for the reshape-transpose-reshape operations that
diff --git a/xla/service/gpu/cudnn_vectorize_convolutions.cc b/xla/service/gpu/cudnn_vectorize_convolutions.cc
index cecf996c3928c..99e3ae9464cfe 100644
--- a/xla/service/gpu/cudnn_vectorize_convolutions.cc
+++ b/xla/service/gpu/cudnn_vectorize_convolutions.cc
@@ -335,7 +335,7 @@ absl::Status ReorderInt8NchwVect(HloCustomCallInstruction* conv,
 // (The dimensions can appear in any order; which is N/C/etc is determined by
 // the convolutions' dnums.)
 static absl::StatusOr<bool> TryRevectorizeConv(
-    const se::CudaComputeCapability& compute_capability,
+    const se::GpuComputeCapability& compute_capability,
     const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
     int vect_size) {
   const Shape& input_shape = conv->operand(0)->shape();
@@ -496,7 +496,7 @@ static absl::StatusOr<bool> TryRevectorizeConv(
 // This requires that C be a multiple of vect_size. CudnnPadForConvolutions can
 // add padding to make this true.
 static absl::StatusOr<bool> TryVectorizeConv(
-    const se::CudaComputeCapability& compute_capability,
+    const se::GpuComputeCapability& compute_capability,
     const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
     int64_t vect_size) {
   const Shape& input_shape = conv->operand(0)->shape();
@@ -625,7 +625,10 @@ absl::StatusOr<bool> CudnnVectorizeConvolutions::Run(
   // Try to (re)vectorize to int8x32 if this is an sm75+ GPU. If we can't,
   // fall back to int8x4.
      bool local_changed = false;
-      if (compute_capability_.IsAtLeast(7, 5)) {
+      auto *ccc = std::get_if<se::CudaComputeCapability>(&compute_capability_);
+      bool isSM75_and_later = false;
+      if (ccc) isSM75_and_later = ccc->IsAtLeast(7, 5);
+      if (isSM75_and_later || se::isROCm(compute_capability_)) {
         TF_ASSIGN_OR_RETURN(
             local_changed,
             TryRevectorizeConv(compute_capability_, cudnn_version_, conv, 32));
diff --git a/xla/service/gpu/cudnn_vectorize_convolutions.h b/xla/service/gpu/cudnn_vectorize_convolutions.h
index 6dde84e023ad7..8cfa3e448ad69 100644
--- a/xla/service/gpu/cudnn_vectorize_convolutions.h
+++ b/xla/service/gpu/cudnn_vectorize_convolutions.h
@@ -52,6 +52,11 @@ class CudnnVectorizeConvolutions : public HloModulePass {
       : compute_capability_(compute_capability),
         cudnn_version_(cudnn_version) {}
 
+  explicit CudnnVectorizeConvolutions(
+      se::RocmComputeCapability compute_capability)
+      : compute_capability_(compute_capability) {}
+
+
   absl::string_view name() const override {
     return "cudnn_vectorize_convolutions";
   }
@@ -61,7 +66,7 @@ class CudnnVectorizeConvolutions : public HloModulePass {
       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
 
  private:
-  const se::CudaComputeCapability compute_capability_;
+  const se::GpuComputeCapability compute_capability_;
   const se::dnn::VersionInfo cudnn_version_;
 };
 
diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h
index 6d06755956d9f..d3d2ac48f22fe 100644
--- a/xla/stream_executor/device_description.h
+++ b/xla/stream_executor/device_description.h
@@ -223,6 +223,14 @@ class RocmComputeCapability {
 
 using GpuComputeCapability =
     std::variant<CudaComputeCapability, RocmComputeCapability>;
 
+static inline bool isCUDA(const GpuComputeCapability& gcc) {
+  return std::holds_alternative<CudaComputeCapability>(gcc);
+}
+
+static inline bool isROCm(const GpuComputeCapability& gcc) {
+  return std::holds_alternative<RocmComputeCapability>(gcc);
+}
+
 // Data that describes the execution target of the StreamExecutor, in terms of
 // important logical parameters. These include dimensionality limits and
 // physical parameters of interest, such as number of cores present on the