Skip to content

Commit

Permalink
updated device_description, gfx fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pemeliya committed Nov 1, 2024
1 parent c33aa29 commit 8959110
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 14 deletions.
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/gpu/fusions/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder,
IrEmitterContext& ir_emitter_context) {
auto* module = builder->GetInsertBlock()->getModule();
if (IsAMDGPU(module) &&
ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr(
0, 6) == "gfx90a") {
ir_emitter_context.rocm_compute_capability().fence_before_barrier()) {
builder->CreateFence(
llvm::AtomicOrdering::SequentiallyConsistent,
builder->getContext().getOrInsertSyncScopeID("workgroup"));
Expand Down
8 changes: 2 additions & 6 deletions third_party/xla/xla/service/gpu/gemm_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1863,12 +1863,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
TF_ASSIGN_OR_RETURN(bool output_is_column_major,
MatrixIsColumnMajor(instr, gemm_backend_config));

if (std::holds_alternative<se::RocmComputeCapability>(gpu_version_)) {
auto rocm_compute_capability_ =
std::get<se::RocmComputeCapability>(gpu_version_);

// as of ROCm 5.5, hipblaslt only supports MI200.
if (rocm_compute_capability_.gcn_arch_name().substr(0, 6) != "gfx90a") {
if (auto *rocm = std::get_if<se::RocmComputeCapability>(&gpu_version_)) {
if (!rocm->has_hipblaslt()) {
return false;
}
}
Expand Down
5 changes: 1 addition & 4 deletions third_party/xla/xla/service/gpu/ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,7 @@ void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
void IrEmitter::MaybeEmitFenceForAMDGPU(llvm::AtomicOrdering atomic_ordering,
const char* sync_scope_id) {
if (IsEmittingForAMDGPU() &&
(ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx90a" ||
ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx940" ||
ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx941" ||
ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx942")) {
ir_emitter_context_->rocm_compute_capability().fence_before_barrier()) {
b_.CreateFence(atomic_ordering,
b_.getContext().getOrInsertSyncScopeID(sync_scope_id));
}
Expand Down
20 changes: 18 additions & 2 deletions third_party/xla/xla/stream_executor/device_description.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ struct CudaComputeCapability {
return !(*this < CudaComputeCapability{other_major, other_minor});
}

bool IsAtLeastVolta() const {
return major >= CudaComputeCapabilities::VOLTA;
}

bool IsAtLeastAmpere() const {
return major >= CudaComputeCapabilities::AMPERE;
}

bool IsAtLeastHopper() const {
return major >= CudaComputeCapabilities::HOPPER;
}

bool operator<(const CudaComputeCapability &other) const {
return ToPair() < other.ToPair();
}
Expand Down Expand Up @@ -173,13 +185,15 @@ class RocmComputeCapability {

bool gfx11_rx7900() const { return gfx_version() == "gfx1100"; }

bool gfx12_rx8900() const { return ((gfx_version() == "gfx1200") || (gfx_version() == "gfx1201")); }

bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); }

bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); }

bool has_fast_fp16_support() const {
return gfx9_mi100_or_later() || gfx10_rx68xx() || gfx10_rx69xx() ||
gfx11_rx7900();
gfx11_rx7900() || gfx12_rx8900();
}

bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); }
Expand Down Expand Up @@ -217,7 +231,9 @@ class RocmComputeCapability {
"gfx90a", // MI200
"gfx940", "gfx941", "gfx942", // MI300
"gfx1030", // RX68xx / RX69xx
"gfx1100" // RX7900
"gfx1100", // RX7900
"gfx1200", // RX8900
"gfx1201" // RX8900
};
};

Expand Down

0 comments on commit 8959110

Please sign in to comment.