Fix internal test and roll-forward PR #16975: Add a few related optimization passes for fp8 gemm custom-calls.

Reverts fd64718

PiperOrigin-RevId: 686037932
derdrdirk authored and Google-ML-Automation committed Oct 15, 2024
1 parent c4c0f4a commit 27f2d9a
Showing 5 changed files with 105 additions and 11 deletions.
@@ -180,8 +180,8 @@ TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32) {
}
)";
const std::string pattern = R"(
CHECK: %convert.2.0 = bf16[
CHECK: %convert.3.0 = bf16[
CHECK: %convert.4.0 = bf16[
CHECK: %convert.5.0 = bf16[
CHECK: "algorithm":"ALG_UNSET"
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
6 changes: 6 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
100644 → 100755
@@ -1595,6 +1595,12 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// Rewrite GEMMs with broadcasted inputs as strided GEMMs.
pipeline.AddPass<GemmBroadcastFoldingRewriter>();

pipeline.AddPass<LayoutNormalization>(&NormalizeLayoutForGpuCustomCalls);

// Layout normalization will create scatters that are not simplified and
// also have unsorted update_window_dims.
pipeline.AddPass<ScatterSimplifier>();

pipeline.AddPass<HostOffloader>(
static_cast<int64_t>(stream_executor::MemoryType::kHost));

69 changes: 61 additions & 8 deletions xla/service/gpu/gpu_compiler_test.cc
100644 → 100755
@@ -411,7 +411,19 @@ ENTRY main {
HloOpcode::kAllGatherDone);
}

TEST_F(GpuCompilerTest,
class GpuCompilerTestWithAutotuneDb : public GpuCompilerTest {
public:
static void SetUpTestSuite() {
std::string path =
tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
"gpu_compiler_test_autotune_db.textproto");
TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));
}

static void TearDownTestSuite() { AutotunerUtil::ClearAutotuneResults(); }
};
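The fixture above moves the autotune-database handling out of the individual test bodies: SetUpTestSuite loads gpu_compiler_test_autotune_db.textproto once before the first test of the suite runs, and TearDownTestSuite clears the results after the last one finishes. For reference, a minimal self-contained sketch of this GoogleTest per-suite setup pattern (hypothetical fixture and member names, not XLA code):

#include <gtest/gtest.h>

// Per-suite setup/teardown: these static hooks run once for the whole test
// suite, not once per TEST_F, so expensive shared state (such as an
// autotune-results file) is loaded a single time and torn down at the end.
class SharedStateTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() { shared_value_ = new int(42); }
  static void TearDownTestSuite() {
    delete shared_value_;
    shared_value_ = nullptr;
  }

 protected:
  static int* shared_value_;
};

int* SharedStateTest::shared_value_ = nullptr;

// Both tests observe the same state created once by SetUpTestSuite().
TEST_F(SharedStateTest, FirstTestSeesSharedState) {
  EXPECT_EQ(*shared_value_, 42);
}

TEST_F(SharedStateTest, SecondTestSeesSameState) {
  EXPECT_EQ(*shared_value_, 42);
}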

TEST_F(GpuCompilerTestWithAutotuneDb,
GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) {
auto cc = backend()
.default_stream_executor()
@@ -456,17 +468,10 @@ ENTRY main {
config.set_replica_count(1);
config.set_num_partitions(1);

// Load autotuning DB. We shouldn't depend on actual execution times in a unit
// test.
std::string path =
tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
"gpu_compiler_test_autotune_db.textproto");
TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
GetOptimizedModule(std::move(module)));
AutotunerUtil::ClearAutotuneResults();
DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
triton_disabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
@@ -486,6 +491,54 @@ ENTRY main {
triton_disabled_module->computation_count());
}

TEST_F(GpuCompilerTestWithAutotuneDb,
CublasF8NumericallySameWithTritonFallbackAndWithoutTriton) {
auto cc = backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
if (!cc.IsAtLeastHopper()) {
GTEST_SKIP()
<< "Autotuning results have only been generated for Hopper GPUs";
}
const absl::string_view hlo_string = R"(
HloModule test
ENTRY main {
p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
dot = bf16[12288,16384]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
bitcast = bf16[] constant(0.956)
broadcast = bf16[12288,16384]{1,0} broadcast(bitcast), dimensions={}
ROOT multiply = bf16[12288,16384]{1,0} multiply(dot, broadcast)
})";

HloModuleConfig config;
DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
triton_enabled_debug_options
.set_xla_gpu_require_complete_aot_autotune_results(true);
config.set_debug_options(triton_enabled_debug_options);

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
GetOptimizedModule(std::move(module)));

DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
triton_disabled_debug_options.set_xla_gpu_cublas_fallback(true);
config.set_debug_options(triton_disabled_debug_options);

TF_ASSERT_OK_AND_ASSIGN(module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_disabled_module,
GetOptimizedModule(std::move(module)));

EXPECT_TRUE(RunAndCompareTwoModules(std::move(triton_enabled_module),
std::move(triton_disabled_module),
ErrorSpec{1e-6, 1e-6}, false));
}
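The final check above runs both optimized modules and compares their outputs with ErrorSpec{1e-6, 1e-6}, i.e. a small absolute and relative tolerance rather than bitwise equality, since the direct cuBLAS fp8 path and the Triton-with-cuBLAS-fallback path may round results differently. A rough self-contained illustration of this kind of mixed-tolerance comparison (not XLA's actual comparison logic, which is more involved):

#include <cmath>
#include <cstdio>

// Accept a value if it is within an absolute tolerance of the expected
// result, or within a relative tolerance scaled by the expected magnitude.
bool ApproximatelyEqual(double expected, double actual, double abs_tol,
                        double rel_tol) {
  const double diff = std::fabs(expected - actual);
  return diff <= abs_tol || diff <= rel_tol * std::fabs(expected);
}

int main() {
  std::printf("%d\n", ApproximatelyEqual(1.0, 1.0000005, 1e-6, 1e-6));  // prints 1
  std::printf("%d\n", ApproximatelyEqual(1.0, 1.1, 1e-6, 1e-6));        // prints 0
  return 0;
}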

class FloatNormalizationTest : public GpuCompilerTest,
public ::testing::WithParamInterface<
std::pair<PrimitiveType, PrimitiveType>> {};
35 changes: 35 additions & 0 deletions xla/service/gpu/gpu_compiler_test_autotune_db.textproto
@@ -71,3 +71,38 @@ results {
}
}
}
results {
device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB"
hlo: "(bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1}, f8e4m3fn[4096,16384]{0,1}, f32[], f32[], f32[], f32[]), custom_call_target=\"__cublas$lt$matmul$f8\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":0.95703125,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"50331648\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"67108864\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
gemm {
}
run_time {
nanos: 1
}
}
}
results {
device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB"
hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_2 = bf16[12288,16384]{1,0} dot(f8e4m3fn[12288,4096]{0,1} tmp_0, f8e4m3fn[4096,16384]{0,1} tmp_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n tmp_3 = bf16[] constant({...})\n tmp_4 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_3), dimensions={}\n ROOT tmp_5 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_2, bf16[12288,16384]{1,0} tmp_4)\n}"
result {
gemm {
algorithm: -1
}
run_time {
nanos: 1
}
}
}
results {
device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB"
hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,12288]{1,0} bitcast(f8e4m3fn[12288,4096]{0,1} tmp_0)\n tmp_2 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_3 = bf16[12288,16384]{1,0} dot(f8e4m3fn[4096,12288]{1,0} tmp_1, f8e4m3fn[4096,16384]{0,1} tmp_2), lhs_contracting_dims={0}, rhs_contracting_dims={0}\n tmp_4 = bf16[] constant({...})\n tmp_5 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_4), dimensions={}\n ROOT tmp_6 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_3, bf16[12288,16384]{1,0} tmp_5)\n}"
result {
gemm {
algorithm: -1
}
run_time {
nanos: 1
}
}
}
2 changes: 1 addition & 1 deletion xla/service/gpu/tests/dot_bf16.hlo
@@ -18,7 +18,7 @@ ENTRY %computation1 {
// -----

// CHECK-SM70: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(f32[1536,6144]{1,0} {{.*}}, f32[32,1536]{1,0} {{.*}}), custom_call_target="__cublas$gemm"
// CHECK-SM80: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(bf16[1536,6144]{1,0} %convert.1.0, bf16[32,1536]{1,0} %b.1), custom_call_target="__cublas$gemm"
// CHECK-SM80: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(bf16[1536,6144]{1,0} %convert.2.0, bf16[32,1536]{1,0} %b.1), custom_call_target="__cublas$gemm"

HloModule module2
