Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a few related optimization passes for fp8 gemm custom-calls. #16975

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,12 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// Rewrite GEMMs with broadcasted inputs as strided GEMMs.
pipeline.AddPass<GemmBroadcastFoldingRewriter>();

pipeline.AddPass<LayoutNormalization>(&NormalizeLayoutForGpuCustomCalls);

// Layout normalization will create scatters that are not simplified and
// also have unsorted update_window_dims.
pipeline.AddPass<ScatterSimplifier>();

pipeline.AddPass<HostOffloader>(
static_cast<int64_t>(stream_executor::MemoryType::kHost));

Expand Down
69 changes: 61 additions & 8 deletions xla/service/gpu/gpu_compiler_test.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,19 @@ ENTRY main {
HloOpcode::kAllGatherDone);
}

TEST_F(GpuCompilerTest,
class GpuCompilerTestWithAutotuneDb : public GpuCompilerTest {
public:
static void SetUpTestSuite() {
std::string path =
tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
"gpu_compiler_test_autotune_db.textproto");
TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));
}

static void TearDownTestSuite() { AutotunerUtil::ClearAutotuneResults(); }
};

TEST_F(GpuCompilerTestWithAutotuneDb,
GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) {
auto cc = backend()
.default_stream_executor()
Expand Down Expand Up @@ -455,17 +467,10 @@ ENTRY main {
config.set_replica_count(1);
config.set_num_partitions(1);

// Load autotuning DB. We shouldn't depend on actual execution times in a unit
// test.
std::string path =
tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
"gpu_compiler_test_autotune_db.textproto");
TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
GetOptimizedModule(std::move(module)));
AutotunerUtil::ClearAutotuneResults();
DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
triton_disabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
Expand All @@ -485,6 +490,54 @@ ENTRY main {
triton_disabled_module->computation_count());
}

TEST_F(GpuCompilerTestWithAutotuneDb,
CublasF8NumericallySameWithTritonFallbackAndWithoutTriton) {
auto cc = backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
if (!cc.IsAtLeastHopper()) {
GTEST_SKIP()
<< "Autotuning results have only been generated for Hopper GPUs";
}
const absl::string_view hlo_string = R"(
HloModule test
ENTRY main {
p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
dot = bf16[12288,16384]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
bitcast = bf16[] constant(0.956)
broadcast = bf16[12288,16384]{1,0} broadcast(bitcast), dimensions={}
ROOT multiply = bf16[12288,16384]{1,0} multiply(dot, broadcast)
})";

HloModuleConfig config;
DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
triton_enabled_debug_options
.set_xla_gpu_require_complete_aot_autotune_results(true);
config.set_debug_options(triton_enabled_debug_options);

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
GetOptimizedModule(std::move(module)));

DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
triton_disabled_debug_options.set_xla_gpu_cublas_fallback(true);
config.set_debug_options(triton_disabled_debug_options);

TF_ASSERT_OK_AND_ASSIGN(module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_disabled_module,
GetOptimizedModule(std::move(module)));

EXPECT_TRUE(RunAndCompareTwoModules(std::move(triton_enabled_module),
std::move(triton_disabled_module),
ErrorSpec{1e-6, 1e-6}, false));
}

class FloatNormalizationTest : public GpuCompilerTest,
public ::testing::WithParamInterface<
std::pair<PrimitiveType, PrimitiveType>> {};
Expand Down
23 changes: 23 additions & 0 deletions xla/service/gpu/gpu_compiler_test_autotune_db.textproto
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,26 @@ results {
}
}
}
results {
device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB"
hlo: "(bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1}, f8e4m3fn[4096,16384]{0,1}, f32[], f32[], f32[], f32[]), custom_call_target=\"__cublas$lt$matmul$f8\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":0.95703125,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"50331648\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"67108864\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
gemm {
}
run_time {
nanos: 1
}
}
}
results {
device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB"
hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_2 = bf16[12288,16384]{1,0} dot(f8e4m3fn[12288,4096]{0,1} tmp_0, f8e4m3fn[4096,16384]{0,1} tmp_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n tmp_3 = bf16[] constant({...})\n tmp_4 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_3), dimensions={}\n ROOT tmp_5 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_2, bf16[12288,16384]{1,0} tmp_4)\n}"
result {
gemm {
algorithm: -1
}
run_time {
nanos: 1
}
}
}
Loading