Add a few related optimization passes for the fp8 gemm rewriter.
elfiegg committed Sep 17, 2024
1 parent 9be7aca commit b85df42
Showing 3 changed files with 85 additions and 0 deletions.
6 changes: 6 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
@@ -1552,6 +1552,12 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// Rewrite GEMMs with broadcasted inputs as strided GEMMs.
pipeline.AddPass<GemmBroadcastFoldingRewriter>();

pipeline.AddPass<LayoutNormalization>(&NormalizeLayoutForGpuCustomCalls);

// Layout normalization will create scatters that are not simplified and
// also have unsorted update_window_dims.
pipeline.AddPass<ScatterSimplifier>();

pipeline.AddPass<HostOffloadLegalize>(
static_cast<int64_t>(stream_executor::MemoryType::kHost),
/* after_layout= */ true);
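
For context, the two passes added in this hunk can also be exercised on their own; below is a minimal sketch (not taken from this commit; the header paths and the wrapper name are assumptions) that applies them back-to-back in the same order as above.

#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/conv_layout_normalization.h"  // NormalizeLayoutForGpuCustomCalls (assumed location)
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/service/layout_normalization.h"
#include "xla/service/scatter_simplifier.h"

// Sketch only: run layout normalization first, then canonicalize the scatters
// it may have produced (e.g. ones with unsorted update_window_dims).
absl::Status NormalizeLayoutsThenSimplifyScatters(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("layout-normalization-sketch");
  pipeline.AddPass<xla::LayoutNormalization>(
      &xla::gpu::NormalizeLayoutForGpuCustomCalls);
  pipeline.AddPass<xla::ScatterSimplifier>();
  return pipeline.Run(module).status();
}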
56 changes: 56 additions & 0 deletions xla/service/gpu/gpu_compiler_test.cc
@@ -473,6 +473,62 @@ ENTRY main {
triton_disabled_module->computation_count());
}

TEST_F(GpuCompilerTest, CublasF8NumericallySameWithTritonFallbackAndWithoutTriton) {
auto cc = backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
if (!cc.IsAtLeastAmpere()) {
GTEST_SKIP() << "Autotuning results have only been generated for Ampere "
<< "and Hopper GPUs";
}
const absl::string_view hlo_string = R"(
HloModule test
ENTRY main {
p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
dot = bf16[12288,16384]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
bitcast = bf16[] constant(0.956)
broadcast = bf16[12288,16384]{1,0} broadcast(bitcast), dimensions={}
ROOT multiply = bf16[12288,16384]{1,0} multiply(dot, broadcast)
})";

HloModuleConfig config;
DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
triton_enabled_debug_options
.set_xla_gpu_require_complete_aot_autotune_results(true);
config.set_debug_options(triton_enabled_debug_options);

// Load autotuning DB. We shouldn't depend on actual execution times in a unit
// test.
std::string path =
tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
"gpu_compiler_test_autotune_db.textproto");
TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
GetOptimizedModule(std::move(module)));

AutotunerUtil::ClearAutotuneResults();
DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
triton_disabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
triton_disabled_debug_options.set_xla_gpu_cublas_fallback(true);
config.set_debug_options(triton_disabled_debug_options);

TF_ASSERT_OK_AND_ASSIGN(module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_disabled_module,
GetOptimizedModule(std::move(module)));

EXPECT_TRUE(RunAndCompareTwoModules(std::move(triton_enabled_module),
std::move(triton_disabled_module),
ErrorSpec{1e-6, 1e-6}, false));
}

class FloatNormalizationTest : public GpuCompilerTest,
public ::testing::WithParamInterface<
std::pair<PrimitiveType, PrimitiveType>> {};
23 changes: 23 additions & 0 deletions xla/service/gpu/gpu_compiler_test_autotune_db.textproto
@@ -71,3 +71,26 @@ results {
}
}
}
results {
device: "CUDA: 9.0, Cores: 114, GPU clock: 1.755 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 50 MB"
hlo: "(bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1}, f8e4m3fn[4096,16384]{0,1}, f32[], f32[], f32[], f32[]), custom_call_target=\"__cublas$lt$matmul$f8\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":0.95703125,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"50331648\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"67108864\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
gemm {
}
run_time {
nanos: 1
}
}
}
results {
device: "CUDA: 9.0, Cores: 114, GPU clock: 1.755 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 50 MB"
hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_2 = bf16[12288,16384]{1,0} dot(f8e4m3fn[12288,4096]{0,1} tmp_0, f8e4m3fn[4096,16384]{0,1} tmp_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n tmp_3 = bf16[] constant({...})\n tmp_4 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_3), dimensions={}\n ROOT tmp_5 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_2, bf16[12288,16384]{1,0} tmp_4)\n}"
result {
gemm {
algorithm: -1
}
run_time {
nanos: 1
}
}
}
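
A note on the alpha_real value in the first results entry: the scale constant in the test HLO is written as 0.956, but it is a bf16 literal, and bf16 keeps only 7 explicit mantissa bits, so the value the gemm rewriter folds into the cuBLAS alpha is the bf16-rounded one:

$$0.956 = 1.912 \times 2^{-1}, \qquad \mathrm{round}\bigl(1.912 \times 2^{7}\bigr) = \mathrm{round}(244.736) = 245, \qquad \frac{245}{2^{7}} \times 2^{-1} = 0.95703125,$$

which matches alpha_real: 0.95703125 in the backend config above.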
