[XLA:GPU] Enable F32_F32_F32 algorithm for triton gemm fusion
Add the F32_F32_F32 algorithm to the list of algorithms for which we can use Triton fusions. We do that in triton_support_legacy.cc. In gemm_fusion.cc we allow using Triton even if there is nothing else to fuse with this dot.

Later, during the autotuning stage, we check the speed of the Triton fusion. If it is slower than unfused cuBLAS, the fusion is inlined and cuBLAS is used instead.
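Concretely, a module like the following (a hypothetical example: the shapes are arbitrary, and dot_f32_f32_f32 is the HLO spelling of ALG_DOT_F32_F32_F32) is now eligible to be wrapped in a Triton gemm fusion even though the dot has nothing else to fuse with:

  HloModule example

  ENTRY e {
    p0 = f32[128,256] parameter(0)
    p1 = f32[256,64] parameter(1)
    ROOT dot = f32[128,64] dot(p0, p1),
      lhs_contracting_dims={1}, rhs_contracting_dims={0},
      algorithm=dot_f32_f32_f32
  }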

PiperOrigin-RevId: 701014493
loislo authored and tensorflower-gardener committed Nov 28, 2024
1 parent 513e824 commit cee0dce
Showing 4 changed files with 79 additions and 8 deletions.
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/triton/BUILD
@@ -489,14 +489,13 @@ xla_test(
"gpu_h100",
"gpu_amd_any",
] + if_oss(["gpu_b100"]),
shard_count = 20,
shard_count = 30,
tags = [
"no_mac",
],
deps = [
":kernel_name_tracer",
":triton_test_utils",
"//xla:array2d",
"//xla:autotuning_proto_cc",
"//xla:error_spec",
"//xla:literal",
Expand All @@ -515,6 +514,7 @@ xla_test(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -31,6 +31,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
@@ -1403,8 +1404,7 @@ class CSVWriter {
};

class AlgorithmsSupportTest
-    : public CanHandleArguments,
-      public WithParamInterface<CanHandleTestsParams::TupleType>,
+    : public WithParamInterface<CanHandleTestsParams::TupleType>,
      public AlgorithmTest {
 public:
  DebugOptions GetDebugOptionsForTest() const override {
@@ -1598,6 +1598,76 @@ TEST_P(AlgorithmsSupportTest, DotNC) {
DumpResults(csv, "backend_support_matrix");
}

+TEST_P(AlgorithmsSupportTest, IsDotAlgorithmSupportedByTriton) {
+  // Tests which dot algorithms are supported by Triton. If support changes,
+  // update the expected results below.
+  const std::string kHloText = R"(
+    HloModule ${module_name}
+    ENTRY e {
+      p0 = f32[${m},${k}] parameter(0)
+      p1 = f32[${k},${n}] parameter(1)
+      ROOT dot = f32[${m},${n}] dot(p0, p1),
+        lhs_contracting_dims={1},
+        rhs_contracting_dims={0},
+        algorithm=${algorithm}
+    }
+  )";
+  auto m = 128;
+  auto n = 128;
+  auto k = 128;
+  auto run = [&](std::string backend, std::string_view pattern,
+                 const DebugOptions& options) -> absl::StatusOr<bool> {
+    auto test_name = absl::StrReplaceAll(TestName(), {{"/", "_"}});
+    auto module_name = absl::StrCat(test_name, "_", backend, "_", m, "_", k,
+                                    "_", n, "_", algorithm_);
+    auto module = GetModule(kHloText,
+                            {{"${module_name}", module_name},
+                             {"${algorithm}", algorithm_},
+                             {"${m}", absl::StrCat(m)},
+                             {"${n}", absl::StrCat(n)},
+                             {"${k}", absl::StrCat(k)}},
+                            options);
+    if (!module.ok()) {
+      return module.status();
+    }
+    std::string module_text = module.value()->ToString();
+    if (!Run(std::move(module.value()), false)) {
+      return absl::InternalError("failed to run module");
+    }
+    return absl::StrContains(module_text, pattern);
+  };
+
+  auto result_or_status = run("triton", kTritonGemmPattern, triton_options_);
+  switch (std::get<0>(GetParam())) {
+    case PC::ALG_UNSET:
+    case PC::ALG_DOT_TF32_TF32_F32:
+    case PC::ALG_DOT_BF16_BF16_F32:
+    case PC::ALG_DOT_BF16_BF16_F32_X3:
+    case PC::ALG_DOT_BF16_BF16_F32_X6:
+    case PC::ALG_DOT_F32_F32_F32:
+      ASSERT_TRUE(result_or_status.status().ok())
+          << "failed to compile " << algorithm_;
+      EXPECT_TRUE(result_or_status.value())
+          << "wrong result for " << algorithm_;
+      break;
+    case PC::ALG_DOT_F64_F64_F64:
+      EXPECT_EQ(result_or_status.status().code(),
+                absl::StatusCode::kUnimplemented);
+      break;
+    // TODO(loislo): Triton implementation needs a fix for dot(inf, 1.0) case.
+    case PC::ALG_DOT_TF32_TF32_F32_X3:
+      ASSERT_TRUE(result_or_status.status().ok());  // Compiles and runs,
+      EXPECT_FALSE(result_or_status.value())        // but not via Triton.
+          << "wrong result for " << algorithm_;
+      break;
+    default:
+      FAIL() << "Uncovered algorithm. Please fix: " << algorithm_;
+  }
+}

INSTANTIATE_TEST_SUITE_P(
AlgorithmsSupportTest, AlgorithmsSupportTest,
Combine(Values(PC::ALG_DOT_BF16_BF16_F32, PC::ALG_DOT_BF16_BF16_F32_X3,
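For reference: kTritonGemmPattern is defined elsewhere in this test file and is matched against the optimized module text. If the rewrite fired, the module contains a kCustom fusion tagged with the Triton gemm backend kind, roughly like this sketch (the instruction and computation names here are illustrative, and the exact printed form may differ):

  %dot_fusion = f32[128,128]{1,0} fusion(%p0, %p1), kind=kCustom,
      calls=%triton_gemm_computation,
      backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}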
@@ -224,8 +224,8 @@ bool IsDotAlgorithmSupportedByTriton(
  auto rocm_compute_capability =
      std::get_if<se::RocmComputeCapability>(&gpu_version);
  switch (algorithm) {
-    case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
    case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
+    case PrecisionConfig::ALG_DOT_F32_F32_F32:
      if (cuda_compute_capability) {
        return true;
      }
@@ -244,9 +244,9 @@
    // TODO(b/326579472): Fix the support of this algorithm and maybe allow it
    // here.
    case PrecisionConfig::ALG_DOT_F16_F16_F32:
-    // TODO(b/311331155): Triton F32 is about 3x slower than Triton TF32 and is
-    // slow to compile. Disable it for now.
-    case PrecisionConfig::ALG_DOT_F32_F32_F32:
+    // TODO(b/381244008): Fix the support of this algorithm in Triton for the
+    // case dot(inf, 1.0) = inf.
+    case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
    default:
      return false;
  }
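Net effect of the two hunks on CUDA: ALG_DOT_F32_F32_F32 moves into the supported set, while ALG_DOT_TF32_TF32_F32_X3 moves out until the dot(inf, 1.0) issue is fixed. A sketch of a call, assuming the parameters elided in this hunk are the algorithm and the gpu_version variant used in the body (the actual signature is cut off here):

  // Hypothetical call shape; see triton_support_legacy.cc for the real one.
  bool f32_supported = IsDotAlgorithmSupportedByTriton(
      PrecisionConfig::ALG_DOT_F32_F32_F32, gpu_version);       // now true
  bool x3_supported = IsDotAlgorithmSupportedByTriton(
      PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3, gpu_version);  // now false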
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc
@@ -747,6 +747,7 @@ absl::StatusOr<Decision> CreateDotFusion(
      algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
      algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32 ||
      algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32 ||
+      algorithm == PrecisionConfig::ALG_DOT_F32_F32_F32 ||
      dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() ||
      dot.sparse_operands()) {
    return Decision::Allow();
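The autotuning fallback described in the commit message is not part of this diff. A minimal, self-contained sketch of that decision (all names here are illustrative, not the actual XLA autotuner API):

  #include <iostream>

  struct MeasuredCandidate {
    double runtime_us;  // measured execution time for one backend/config
  };

  // Keep the Triton gemm fusion only if it beats the unfused cuBLAS call;
  // otherwise the fusion is inlined and cuBLAS is used instead.
  bool KeepTritonFusion(MeasuredCandidate best_triton,
                        MeasuredCandidate cublas) {
    return best_triton.runtime_us < cublas.runtime_us;
  }

  int main() {
    std::cout << (KeepTritonFusion({42.0}, {37.5}) ? "keep Triton fusion"
                                                   : "inline, use cuBLAS")
              << '\n';  // prints "inline, use cuBLAS" since 42.0 > 37.5
  }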
