Skip to content

Commit

Permalink
updated device_description, gfx fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pemeliya committed Oct 30, 2024
1 parent 26c72d0 commit 0f73609
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 42 deletions.
3 changes: 1 addition & 2 deletions tensorflow/core/kernels/matmul_op_fused.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,7 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {

#if TF_HIPBLASLT
auto cap = stream->GetRocmComputeCapability();
// as of ROCm 5.5, hipblaslt only supports MI200.
if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") use_cudnn = true;
if (!cap.has_hipblaslt()) use_cudnn = true;
#endif

BlasScratchAllocator scratch_allocator(context);
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/core/kernels/matmul_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,8 +602,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
#if TF_HIPBLASLT
if (!std::is_same_v<Scalar, float>) bCublasLtSupport = false;
auto cap = stream->GetRocmComputeCapability();
// as of ROCm 5.5, hipblaslt only supports MI200.
if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") bCublasLtSupport = false;
if (!cap.has_hipblaslt()) bCublasLtSupport = false;
#endif
if (EnableCublasLtGemm() && bCublasLtSupport) {
static const int64_t max_scratch_size =
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/gpu/fusions/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,7 @@ void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder,
IrEmitterContext& ir_emitter_context) {
auto* module = builder->GetInsertBlock()->getModule();
if (IsAMDGPU(module) &&
ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr(
0, 6) == "gfx90a") {
ir_emitter_context.rocm_compute_capability().fence_before_barrier()) {
builder->CreateFence(
llvm::AtomicOrdering::SequentiallyConsistent,
builder->getContext().getOrInsertSyncScopeID("workgroup"));
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/gpu/fusions/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder,
IrEmitterContext& ir_emitter_context) {
auto* module = builder->GetInsertBlock()->getModule();
if (IsAMDGPU(module) &&
ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr(
0, 6) == "gfx90a") {
ir_emitter_context.rocm_compute_capability().fence_before_barrier()) {
builder->CreateFence(
llvm::AtomicOrdering::SequentiallyConsistent,
builder->getContext().getOrInsertSyncScopeID("workgroup"));
Expand Down
8 changes: 2 additions & 6 deletions third_party/xla/xla/service/gpu/gemm_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1863,12 +1863,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
TF_ASSIGN_OR_RETURN(bool output_is_column_major,
MatrixIsColumnMajor(instr, gemm_backend_config));

if (std::holds_alternative<se::RocmComputeCapability>(gpu_version_)) {
auto rocm_compute_capability_ =
std::get<se::RocmComputeCapability>(gpu_version_);

// as of ROCm 5.5, hipblaslt only supports MI200.
if (rocm_compute_capability_.gcn_arch_name().substr(0, 6) != "gfx90a") {
if (auto *rocm = std::get_if<se::RocmComputeCapability>(&gpu_version_)) {
if (!rocm->has_hipblaslt()) {
return false;
}
}
Expand Down
5 changes: 1 addition & 4 deletions third_party/xla/xla/service/gpu/ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,7 @@ void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
void IrEmitter::MaybeEmitFenceForAMDGPU(llvm::AtomicOrdering atomic_ordering,
const char* sync_scope_id) {
if (IsEmittingForAMDGPU() &&
(ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx90a" ||
ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx940" ||
ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx941" ||
ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr(0, 6) == "gfx942")) {
ir_emitter_context_->rocm_compute_capability().fence_before_barrier()) {
b_.CreateFence(atomic_ordering,
b_.getContext().getOrInsertSyncScopeID(sync_scope_id));
}
Expand Down
79 changes: 55 additions & 24 deletions third_party/xla/xla/stream_executor/device_description.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ struct CudaComputeCapability {
return !(*this < CudaComputeCapability{other_major, other_minor});
}

// Returns true when the major version is CudaComputeCapabilities::VOLTA or
// newer. Only `major` is consulted.
bool IsAtLeastVolta() const {
  return CudaComputeCapabilities::VOLTA <= major;
}

// Returns true when the major version is CudaComputeCapabilities::AMPERE or
// newer. Only `major` is consulted.
bool IsAtLeastAmpere() const {
  return CudaComputeCapabilities::AMPERE <= major;
}

// Returns true when the major version is CudaComputeCapabilities::HOPPER or
// newer. Only `major` is consulted.
bool IsAtLeastHopper() const {
  return CudaComputeCapabilities::HOPPER <= major;
}

// Orders two capabilities by comparing the values returned by ToPair().
bool operator<(const CudaComputeCapability &other) const {
  const auto lhs = ToPair();
  const auto rhs = other.ToPair();
  return lhs < rhs;
}
Expand Down Expand Up @@ -147,38 +159,57 @@ class RocmComputeCapability {
return absl::StrJoin(kSupportedGfxVersions, ", ");
}

bool has_nhwc_layout_support() const {
static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940",
"gfx941", "gfx942"};
// True when gfx_version() is "gfx908" (listed as MI100 in
// kSupportedGfxVersions).
bool gfx9_mi100() const {
  return gfx_version() == "gfx908";
}

// True when gfx_version() is "gfx90a" (listed as MI200 in
// kSupportedGfxVersions).
bool gfx9_mi200() const {
  return gfx_version() == "gfx90a";
}

bool gfx9_mi300() const {
static constexpr absl::string_view kList[] = {"gfx940", "gfx941", "gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
}

bool has_bf16_dtype_support() const {
bool gfx9_mi100_or_later() const {
static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940",
"gfx941", "gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
}

bool has_fast_fp16_support() const {
static constexpr absl::string_view kList[] = {"gfx906", "gfx908", "gfx90a",
"gfx940", "gfx941", "gfx942",
"gfx1030", "gfx1100"};
bool gfx9_mi200_or_later() const {
static constexpr absl::string_view kList[] = {"gfx90a", "gfx940", "gfx941",
"gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
}

bool has_mfma_instr_support() const {
static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940",
"gfx941", "gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
// True when gfx_version() is "gfx1030" (listed as RX68xx / RX69xx in
// kSupportedGfxVersions).
bool gfx10_rx68xx() const {
  return gfx_version() == "gfx1030";
}

// True when gfx_version() is "gfx1030". This is the same check as
// gfx10_rx68xx(): kSupportedGfxVersions lists gfx1030 as RX68xx / RX69xx.
bool gfx10_rx69xx() const {
  return gfx_version() == "gfx1030";
}

// True when gfx_version() is "gfx1100" (listed as RX7900 in
// kSupportedGfxVersions).
bool gfx11_rx7900() const {
  return gfx_version() == "gfx1100";
}

// True when gfx_version() is "gfx1200" or "gfx1201" (both listed as RX8900 in
// kSupportedGfxVersions).
bool gfx12_rx8900() const {
  const auto &current = gfx_version();
  return current == "gfx1200" || current == "gfx1201";
}

// NHWC layout support is reported exactly when gfx9_mi100_or_later() holds.
bool has_nhwc_layout_support() const {
  return gfx9_mi100_or_later();
}

// bf16 dtype support is reported exactly when gfx9_mi100_or_later() holds.
bool has_bf16_dtype_support() const {
  return gfx9_mi100_or_later();
}

bool has_fast_fp16_support() const {
return gfx9_mi100_or_later() || gfx10_rx68xx() || gfx10_rx69xx() ||
gfx11_rx7900() || gfx12_rx8900();
}

// MFMA instruction support is reported exactly when gfx9_mi100_or_later()
// holds.
bool has_mfma_instr_support() const {
  return gfx9_mi100_or_later();
}

bool has_fp16_atomics_support() const {
// TODO(rocm): Check. This should be the same as has_fast_fp16_support().
static constexpr absl::string_view kList[] = {"gfx90a", "gfx940", "gfx941",
"gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
return gfx9_mi200_or_later();
}

// Returns true for every gfx version except "gfx900" and "gfx906".
// The MaybeEmitFenceForAMDGPU helpers use this to decide whether to emit a
// fence.
bool fence_before_barrier() const {
  const auto &current = gfx_version();
  return !(current == "gfx900" || current == "gfx906");
}

// hipblaslt availability is reported exactly when gfx9_mi200_or_later()
// holds.
bool has_hipblaslt() const {
  return gfx9_mi200_or_later();
}

RocmComputeCapabilityProto ToProto() const {
RocmComputeCapabilityProto proto;
proto.set_gcn_arch_name(gcn_arch_name_);
Expand All @@ -193,15 +224,15 @@ class RocmComputeCapability {
std::string gcn_arch_name_ = "gfx000"; // default to invalid arch.

static constexpr absl::string_view kSupportedGfxVersions[]{
"gfx900", // MI25
"gfx906", // MI50 / MI60
"gfx908", // MI100
"gfx90a", // MI200
"gfx940", // MI300
"gfx941", // MI300
"gfx942", // MI300
"gfx1030", // RX68xx / RX69xx
"gfx1100" // RX7900
"gfx900", // MI25
"gfx906", // MI50 / MI60
"gfx908", // MI100
"gfx90a", // MI200
"gfx940", "gfx941", "gfx942", // MI300
"gfx1030", // RX68xx / RX69xx
"gfx1100", // RX7900
"gfx1200", // RX8900
"gfx1201" // RX8900
};
};

Expand Down

0 comments on commit 0f73609

Please sign in to comment.