Skip to content

Commit

Permalink
PR #16975: Add a few related optimization passes for fp8 gemm custom-…
Browse files Browse the repository at this point in the history
…calls.

Imported from GitHub PR #16975

This caused convergence issue for fp8 training, tested on GPT3 models:

Before:
```
NETWORK             BACKEND MATH SDPA XLA_EXTRAS      GPUs STEPS/SEC     LOSS
WALLSECS
GPT5B                   XLA  fp8   FA    8     1.064 11.019     1571
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 11.015041
[PAX STATUS] step_i: 200, training loss: 11.016165
[PAX STATUS] step_i: 300, training loss: 11.016386
[PAX STATUS] step_i: 400, training loss: 11.014653
[PAX STATUS] step_i: 500, training loss: 11.014734
[PAX STATUS] step_i: 600, training loss: 11.01613
[PAX STATUS] step_i: 700, training loss: 11.009399
[PAX STATUS] step_i: 800, training loss: 11.017071
[PAX STATUS] step_i: 900, training loss: 11.014582
[PAX STATUS] step_i: 1000, training loss: 11.013434
[PAX STATUS] step_i: 1100, training loss: 11.021271
[PAX STATUS] step_i: 1200, training loss: 11.008364
[PAX STATUS] step_i: 1300, training loss: 11.0198145
[PAX STATUS] step_i: 1400, training loss: 11.01253
[PAX STATUS] step_i: 1500, training loss: 11.019016
```

After:
```
NETWORK             BACKEND MATH SDPA GPUs STEPS/SEC  LOSS WALLSECS
GPT5B                   XLA  fp8   FA    8     1.020 3.797     1647
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 6.150083
[PAX STATUS] step_i: 200, training loss: 5.8871064
[PAX STATUS] step_i: 300, training loss: 5.4491887
[PAX STATUS] step_i: 400, training loss: 5.6384015
[PAX STATUS] step_i: 500, training loss: 5.273538
[PAX STATUS] step_i: 600, training loss: 5.2011905
[PAX STATUS] step_i: 700, training loss: 4.903013
[PAX STATUS] step_i: 800, training loss: 4.62972
[PAX STATUS] step_i: 900, training loss: 4.507727
[PAX STATUS] step_i: 1000, training loss: 4.625259
[PAX STATUS] step_i: 1100, training loss: 4.428066
[PAX STATUS] step_i: 1200, training loss: 4.252451
[PAX STATUS] step_i: 1300, training loss: 3.8448389
[PAX STATUS] step_i: 1400, training loss: 3.8578327
[PAX STATUS] step_i: 1500, training loss: 3.796958
```
Copybara import of the project:

--
81af29c by Elfie Guo <[email protected]>:

Add a few related optimization pass for fp8 gemm rerwriter.

Merging this change closes #16975

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16975 from elfiegg:pass 81af29c
PiperOrigin-RevId: 684532401
  • Loading branch information
elfiegg authored and Google-ML-Automation committed Oct 10, 2024
1 parent d05087b commit 0358845
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 8 deletions.
6 changes: 6 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,12 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// Rewrite GEMMs with broadcasted inputs as strided GEMMs.
pipeline.AddPass<GemmBroadcastFoldingRewriter>();

pipeline.AddPass<LayoutNormalization>(&NormalizeLayoutForGpuCustomCalls);

// Layout normalization will create scatters that are not simplified and
// also have unsorted update_window_dims.
pipeline.AddPass<ScatterSimplifier>();

pipeline.AddPass<HostOffloader>(
static_cast<int64_t>(stream_executor::MemoryType::kHost));

Expand Down
69 changes: 61 additions & 8 deletions xla/service/gpu/gpu_compiler_test.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,19 @@ ENTRY main {
HloOpcode::kAllGatherDone);
}

TEST_F(GpuCompilerTest,
class GpuCompilerTestWithAutotuneDb : public GpuCompilerTest {
public:
static void SetUpTestSuite() {
std::string path =
tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
"gpu_compiler_test_autotune_db.textproto");
TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));
}

static void TearDownTestSuite() { AutotunerUtil::ClearAutotuneResults(); }
};

TEST_F(GpuCompilerTestWithAutotuneDb,
GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) {
auto cc = backend()
.default_stream_executor()
Expand Down Expand Up @@ -455,17 +467,10 @@ ENTRY main {
config.set_replica_count(1);
config.set_num_partitions(1);

// Load autotuning DB. We shouldn't depend on actual execution times in a unit
// test.
std::string path =
tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
"gpu_compiler_test_autotune_db.textproto");
TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
GetOptimizedModule(std::move(module)));
AutotunerUtil::ClearAutotuneResults();
DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
triton_disabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
Expand All @@ -485,6 +490,54 @@ ENTRY main {
triton_disabled_module->computation_count());
}

TEST_F(GpuCompilerTestWithAutotuneDb,
CublasF8NumericallySameWithTritonFallbackAndWithoutTriton) {
auto cc = backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
if (!cc.IsAtLeastHopper()) {
GTEST_SKIP()
<< "Autotuning results have only been generated for Hopper GPUs";
}
const absl::string_view hlo_string = R"(
HloModule test
ENTRY main {
p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
dot = bf16[12288,16384]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
bitcast = bf16[] constant(0.956)
broadcast = bf16[12288,16384]{1,0} broadcast(bitcast), dimensions={}
ROOT multiply = bf16[12288,16384]{1,0} multiply(dot, broadcast)
})";

HloModuleConfig config;
DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
triton_enabled_debug_options
.set_xla_gpu_require_complete_aot_autotune_results(true);
config.set_debug_options(triton_enabled_debug_options);

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
GetOptimizedModule(std::move(module)));

DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
triton_disabled_debug_options.set_xla_gpu_cublas_fallback(true);
config.set_debug_options(triton_disabled_debug_options);

TF_ASSERT_OK_AND_ASSIGN(module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_disabled_module,
GetOptimizedModule(std::move(module)));

EXPECT_TRUE(RunAndCompareTwoModules(std::move(triton_enabled_module),
std::move(triton_disabled_module),
ErrorSpec{1e-6, 1e-6}, false));
}

class FloatNormalizationTest : public GpuCompilerTest,
public ::testing::WithParamInterface<
std::pair<PrimitiveType, PrimitiveType>> {};
Expand Down
23 changes: 23 additions & 0 deletions xla/service/gpu/gpu_compiler_test_autotune_db.textproto
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,26 @@ results {
}
}
}
results {
device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB"
hlo: "(bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1}, f8e4m3fn[4096,16384]{0,1}, f32[], f32[], f32[], f32[]), custom_call_target=\"__cublas$lt$matmul$f8\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":0.95703125,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"50331648\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"67108864\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
gemm {
}
run_time {
nanos: 1
}
}
}
results {
device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB"
hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_2 = bf16[12288,16384]{1,0} dot(f8e4m3fn[12288,4096]{0,1} tmp_0, f8e4m3fn[4096,16384]{0,1} tmp_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n tmp_3 = bf16[] constant({...})\n tmp_4 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_3), dimensions={}\n ROOT tmp_5 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_2, bf16[12288,16384]{1,0} tmp_4)\n}"
result {
gemm {
algorithm: -1
}
run_time {
nanos: 1
}
}
}

0 comments on commit 0358845

Please sign in to comment.