diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.cc b/third_party/xla/xla/service/gpu/fusions/transpose.cc index 877d282b16db36..5e97e432a83130 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose.cc @@ -48,8 +48,7 @@ void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context) { auto* module = builder->GetInsertBlock()->getModule(); if (IsAMDGPU(module) && - ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr( - 0, 6) == "gfx90a") { + ir_emitter_context.rocm_compute_capability().fence_before_barrier()) { builder->CreateFence( llvm::AtomicOrdering::SequentiallyConsistent, builder->getContext().getOrInsertSyncScopeID("workgroup")); diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_rewriter.cc index a8984acdac5974..7676f54f8aab85 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter.cc @@ -1863,12 +1863,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { TF_ASSIGN_OR_RETURN(bool output_is_column_major, MatrixIsColumnMajor(instr, gemm_backend_config)); - if (std::holds_alternative(gpu_version_)) { - auto rocm_compute_capability_ = - std::get(gpu_version_); - - // as of ROCm 5.5, hipblaslt only supports MI200. - if (rocm_compute_capability_.gcn_arch_name().substr(0, 6) != "gfx90a") { + if (auto *rocm = std::get_if(&gpu_version_)) { + if (!rocm->has_hipblaslt()) { return false; } } diff --git a/third_party/xla/xla/service/gpu/ir_emitter.cc b/third_party/xla/xla/service/gpu/ir_emitter.cc index d0401d06052801..3f3caa7f78bc8c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter.cc @@ -264,10 +264,7 @@ void IrEmitter::BindFusionArguments(const HloInstruction* fusion, void IrEmitter::MaybeEmitFenceForAMDGPU(llvm::AtomicOrdering atomic_ordering, const char* sync_scope_id) { if (IsEmittingForAMDGPU() && - (ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx90a" || - ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx940" || - ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx941" || - ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx942")) { + ir_emitter_context_->rocm_compute_capability().fence_before_barrier()) { b_.CreateFence(atomic_ordering, b_.getContext().getOrInsertSyncScopeID(sync_scope_id)); } diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index 8e48e8c4ab74ee..8e126ba9d129df 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -68,6 +68,18 @@ struct CudaComputeCapability { return !(*this < CudaComputeCapability{other_major, other_minor}); } + bool IsAtLeastVolta() const { + return major >= CudaComputeCapabilities::VOLTA; + } + + bool IsAtLeastAmpere() const { + return major >= CudaComputeCapabilities::AMPERE; + } + + bool IsAtLeastHopper() const { + return major >= CudaComputeCapabilities::HOPPER; + } + bool operator<(const CudaComputeCapability &other) const { return ToPair() < other.ToPair(); } @@ -173,13 +185,15 @@ class RocmComputeCapability { bool gfx11_rx7900() const { return gfx_version() == "gfx1100"; } + bool gfx12_rx8900() const { return ((gfx_version() == "gfx1200") || (gfx_version() == "gfx1201")); } + bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); } bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); } bool has_fast_fp16_support() const { return gfx9_mi100_or_later() || gfx10_rx68xx() || gfx10_rx69xx() || - gfx11_rx7900(); + gfx11_rx7900() || gfx12_rx8900(); } bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); } @@ -217,7 +231,9 @@ class RocmComputeCapability { "gfx90a", // MI200 "gfx940", "gfx941", "gfx942", // MI300 "gfx1030", // RX68xx / RX69xx - "gfx1100" // RX7900 + "gfx1100", // RX7900 + "gfx1200", // RX8900 + "gfx1201" // RX8900 }; };