
[ROCm] Disable gemm triton fusions for ROCm, until autotuner is funct… #50

Merged
merged 1 commit on Oct 4, 2024
@@ -99,6 +99,19 @@ class TritonTest : public GpuCodegenTest {

 class TritonGemmTest : public TritonTest {
  public:
+  se::GpuComputeCapability GetGpuComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  void SetUp() override {
+    if (std::holds_alternative<se::RocmComputeCapability>(GetGpuComputeCapability())) {
+      GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled.";
+    }
+  }
+
   DebugOptions GetDebugOptionsForTest() override {
     DebugOptions debug_options = TritonTest::GetDebugOptionsForTest();
     // Do not fall back to cuBLAS, we are testing Triton.
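Note on the pattern above: `se::GpuComputeCapability` is a `std::variant` over the CUDA and ROCm capability types, so `SetUp` can branch on which alternative the device reports. A minimal standalone sketch of the same dispatch, using hypothetical `CudaCap`/`RocmCap` stand-ins for the real stream_executor types:

    #include <iostream>
    #include <string>
    #include <variant>

    // Hypothetical stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
    struct CudaCap { int major = 8; int minor = 0; };
    struct RocmCap { std::string gfx_version = "gfx90a"; };
    using GpuComputeCapability = std::variant<CudaCap, RocmCap>;

    int main() {
      GpuComputeCapability cap = RocmCap{};
      // Same check the tests use: skip when the variant holds the ROCm alternative.
      if (std::holds_alternative<RocmCap>(cap)) {
        std::cout << "ROCm device -> skipping Triton GEMM tests\n";
      } else {
        std::cout << "CUDA device -> running Triton GEMM tests\n";
      }
    }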
13 changes: 13 additions & 0 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc
@@ -28,6 +28,19 @@ namespace {

 class TritonGemmTest : public GpuCodegenTest {
  public:
+  se::GpuComputeCapability GetGpuComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  void SetUp() override {
+    if (std::holds_alternative<se::RocmComputeCapability>(GetGpuComputeCapability())) {
+      GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled.";
+    }
+  }
+
   DebugOptions GetDebugOptionsForTest() override {
     DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
     debug_options.set_xla_gpu_cublas_fallback(false);
@@ -59,6 +59,19 @@ class MixedTypeTest : public GpuCodegenTest,
       .cuda_compute_capability();
   }

+  se::GpuComputeCapability GetGpuComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  void SetUp() override {
+    if (std::holds_alternative<se::RocmComputeCapability>(GetGpuComputeCapability())) {
+      GTEST_SKIP() << "Related fusions are not performed on ROCm without Triton.";
+    }
+  }
+
   DebugOptions GetDebugOptionsForTest() override {
     DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
     // We are testing Triton, remove cuBLAS fallback for these tests.
5 changes: 2 additions & 3 deletions xla/service/gpu/gpu_compiler.cc
@@ -1387,9 +1387,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
   const auto* rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);

   if (debug_options.xla_gpu_enable_triton_gemm() &&
-      ((cuda_cc != nullptr &&
-        cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) ||
-       rocm_cc != nullptr)) {
+      (cuda_cc != nullptr &&
+       cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE))) {
     pipeline.AddPass<GemvRewriter>();
     pipeline.AddPass<GemmFusion>(gpu_version);
   }
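Net effect of this hunk: `GemvRewriter` and `GemmFusion` are now added to the pipeline only for CUDA devices at Ampere (compute capability 8.0) or newer; a ROCm capability no longer satisfies the condition, so GEMMs on ROCm stay on the existing BLAS-library path. A sketch of the resulting predicate in isolation, reusing the hypothetical `CudaCap`/`RocmCap` variant from the earlier sketch (the real check is `cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)`):

    // Post-change gating logic, extracted for clarity (assumes the
    // GpuComputeCapability variant sketched above).
    bool ShouldAddTritonGemmFusion(bool triton_gemm_enabled,
                                   const GpuComputeCapability& gpu_version) {
      const auto* cuda_cc = std::get_if<CudaCap>(&gpu_version);
      // After this PR, a ROCm capability never enables the fusion passes.
      return triton_gemm_enabled && cuda_cc != nullptr && cuda_cc->major >= 8;
    }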
7 changes: 7 additions & 0 deletions xla/service/gpu/gpu_compiler_test.cc
@@ -76,6 +76,13 @@ class GpuCompilerTest : public HloTestBase {
     return tensorflow::down_cast<GpuCompiler*>(compiler)
         ->RunPostSchedulingPipelines(module, 4 * 1024 * 1024, gpu_device_info);
   }
+
+  const stream_executor::GpuComputeCapability& GpuComputeComp() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
 };

 TEST_F(GpuCompilerTest, CompiledProgramsCount) {
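The new `GpuComputeComp` accessor mirrors the helpers added to the Triton tests and gives `GpuCompilerTest` cases a way to branch on the device backend. A hypothetical usage sketch (not part of this diff; the test name is made up):

    TEST_F(GpuCompilerTest, TritonGemmExpectationsCudaOnly) {
      if (std::holds_alternative<stream_executor::RocmComputeCapability>(
              GpuComputeComp())) {
        GTEST_SKIP() << "Triton GEMM fusions are disabled on ROCm by this PR.";
      }
      // ... CUDA-specific expectations go here ...
    }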