Commit
Add support for hipBLASLt with autotuning
amdhipblaslt_plugin compile

gpu_executable compiled

gpublas_lt thunk builds

unit test compiles

more adaptations for workspace buffer and mhlo

starting autotuner backport

updated picker

adding autotuner support

autotuner compiles

autotuner update

remaining autotuner updates

remaining build fixes

added missing tf32 support

gpu_blas_lt_gemm_runner

disable check

fix location of header files

forward gemm calls to gpu blas lt runner

minor fix

explicit instantiation of ThenBlasGemm

use default as fallback algorithm
pemeliya committed Dec 6, 2024
1 parent 0c0fe34 commit 60940c5
Showing 81 changed files with 6,773 additions and 3,536 deletions.
2 changes: 1 addition & 1 deletion build_rocm_python3
@@ -26,7 +26,7 @@ done
shift "$((OPTIND-1))"

# First positional argument (if any) specifies the ROCM_INSTALL_DIR
- ROCM_INSTALL_DIR=/opt/rocm-6.2.0
+ ROCM_INSTALL_DIR=$(realpath /opt/rocm)
if [[ -n $1 ]]; then
ROCM_INSTALL_DIR=$1
fi
71 changes: 67 additions & 4 deletions tensorflow/compiler/xla/debug_options_flags.cc
@@ -38,7 +38,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_llvm_enable_invariant_load_metadata(true);
opts.set_xla_llvm_disable_expensive_passes(false);
opts.set_xla_backend_optimization_level(3);
- opts.set_xla_gpu_autotune_level(4);
+ opts.set_xla_gpu_autotune_level(0);
opts.set_xla_cpu_multi_thread_eigen(true);
opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
opts.set_xla_gpu_asm_extra_flags("");
@@ -74,7 +74,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

// Note: CublasLt will be used for FP8 GEMMs regardless of the value of this
// flag.
- opts.set_xla_gpu_enable_cublaslt(false);
+ opts.set_xla_gpu_enable_cublaslt(true);

// TODO(b/258036887): Enable once CUDA Graphs are fully supported.
opts.set_xla_gpu_cuda_graph_level(0);
@@ -122,7 +122,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_partitioning_algorithm(
DebugOptions::PARTITIONING_ALGORITHM_NOOP);

- opts.set_xla_gpu_enable_triton_gemm(true);
+ opts.set_xla_gpu_enable_triton_gemm(false);
opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
opts.set_xla_gpu_triton_gemm_any(false);

@@ -131,6 +131,19 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(false);

opts.set_xla_gpu_collective_inflation_factor(1);

+ opts.set_xla_gpu_autotune_gemm_rtol(0.1f);
+
+ opts.set_xla_gpu_redzone_padding_bytes(8 * 1024 * 1024);
+
+ // Minimum combined size of matrices in matrix multiplication to
+ // be rewritten to cuBLAS or Triton kernel call.
+ // This threshold is a conservative estimate and has been measured
+ // to be always beneficial (up to generally several times faster)
+ // on V100 and H100 GPUs. See openxla/xla #9319 for details.
+ const int64_t kDefaultMinGemmRewriteSize = 100;
+ opts.set_xla_gpu_gemm_rewrite_size_threshold(kDefaultMinGemmRewriteSize);

return opts;
}
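The comment above spells out the rewrite rule: a dot op is only handed to a BLAS library or Triton kernel once the matrices are big enough to amortize the call overhead. As a minimal sketch of that size test (function and parameter names here are hypothetical illustrations, not code from this commit):

```c++
#include <cstdint>

// Illustrative only: the rewrite threshold compares the combined non-batch
// element count of both GEMM operands against the configured minimum.
bool MeetsGemmRewriteThreshold(int64_t m, int64_t k, int64_t n,
                               int64_t threshold) {
  // Elements of the LHS (m x k) plus elements of the RHS (k x n).
  return m * k + k * n >= threshold;
}
```

With the default threshold of 100, a 4x8 times 8x4 product (32 + 32 = 64 elements) would stay on the elemental emitter, while a 16x16 square product (512 elements) would qualify for a rewrite.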

@@ -209,6 +222,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
};
};

+ auto float_setter_for =
+     [debug_options](void (DebugOptions::*member_setter)(float)) {
+       return [debug_options, member_setter](float value) {
+         (debug_options->*member_setter)(value);
+         return true;
+       };
+     };

// Custom "sub-parser" lambda for xla_gpu_shape_checks.
auto setter_for_xla_gpu_shape_checks =
[debug_options](const std::string& value) {
@@ -527,7 +548,35 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
debug_options->xla_gpu_autotune_level(),
"Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = "
"on+init; 3 = on+init+reinit; 4 = on+init+reinit+check."));
"on+init; 3 = on+init+reinit; 4 = on+init+reinit+check; "
"5 = on+init+reinit+check and skip WRONG_RESULT solutions. See also "
"the related flag xla_gpu_autotune_gemm_rtol. Remark that, setting the "
"level to 5 only makes sense if you are sure that the reference (first "
"in the list) solution is numerically CORRECT. Otherwise, the autotuner "
"might discard many other correct solutions based on the failed "
"BufferComparator test."));
+ flag_list->push_back(tsl::Flag(
+     "xla_gpu_dump_autotune_results_to",
+     string_setter_for(&DebugOptions::set_xla_gpu_dump_autotune_results_to),
+     debug_options->xla_gpu_dump_autotune_results_to(),
+     "File to write autotune results to. It will be a binary file unless the "
+     "name ends with .txt or .textproto. Warning: The results are written at "
+     "every compilation, possibly multiple times per process. This only works "
+     "on CUDA. In tests, the TEST_UNDECLARED_OUTPUTS_DIR prefix can be used "
+     "to write to their output directory."));
+ flag_list->push_back(tsl::Flag(
+     "xla_gpu_load_autotune_results_from",
+     string_setter_for(&DebugOptions::set_xla_gpu_load_autotune_results_from),
+     debug_options->xla_gpu_load_autotune_results_from(),
+     "File to load autotune results from. It will be considered a binary file "
+     "unless the name ends with .txt or .textproto. It will be loaded at most "
+     "once per process. This only works on CUDA. In tests, the TEST_WORKSPACE "
+     "prefix can be used to load files from their data dependencies."));
+ flag_list->push_back(tsl::Flag(
+     "xla_gpu_autotune_gemm_rtol",
+     float_setter_for(&DebugOptions::set_xla_gpu_autotune_gemm_rtol),
+     debug_options->xla_gpu_autotune_gemm_rtol(),
+     "Relative precision for comparing GEMM solutions vs the reference one"));
flag_list->push_back(tsl::Flag(
"xla_force_host_platform_device_count",
int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count),
@@ -823,6 +872,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
bool_setter_for(&DebugOptions::set_xla_gpu_enable_cublaslt),
debug_options->xla_gpu_enable_cublaslt(),
"Use cuBLASLt for GEMMs when possible."));
+ flag_list->push_back(tsl::Flag(
+     "xla_gpu_gemm_rewrite_size_threshold",
+     int64_setter_for(&DebugOptions::set_xla_gpu_gemm_rewrite_size_threshold),
+     debug_options->xla_gpu_gemm_rewrite_size_threshold(),
+     "Threshold until which elemental dot emitter is preferred for GEMMs "
+     "(minimum combined number of elements of both matrices "
+     "in non-batch dimensions to be considered for a rewrite)."));
flag_list->push_back(tsl::Flag(
"xla_gpu_cuda_graph_level",
int32_setter_for(&DebugOptions::set_xla_gpu_cuda_graph_level),
@@ -994,6 +1050,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_triton_gemm_any(),
"Use Triton-based matrix multiplication for any GEMM it "
"supports without filtering only faster ones."));
+ flag_list->push_back(tsl::Flag(
+     "xla_gpu_redzone_padding_bytes",
+     int64_setter_for(&DebugOptions::set_xla_gpu_redzone_padding_bytes),
+     debug_options->xla_gpu_redzone_padding_bytes(),
+     "Amount of padding the redzone allocator will put on one side of each "
+     "buffer it allocates. (So the buffer's total size will be increased by "
+     "2x this value.)"));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
@@ -168,6 +168,7 @@ def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandS
Arg<LHLO_Buffer, "", [MemWrite]>:$d,
Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$bias,
Arg<Optional<LHLO_Buffer>, "", [MemRead, MemWrite]>:$aux,
+ Arg<Optional<LHLO_Buffer>, "", [MemRead, MemWrite]>:$workspace,
MHLO_DotDimensionNumbers:$dot_dimension_numbers,
MHLO_PrecisionConfigAttr:$precision_config,
F64Attr:$alpha_real,
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/service/computation_placer.cc
@@ -163,7 +163,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
ComputationPlacerCreationFunction creation_function) {
absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_);
auto* computation_placers = GetPlatformComputationPlacers();
- CHECK(computation_placers->find(platform_id) == computation_placers->end());
+ // CHECK(computation_placers->find(platform_id) == computation_placers->end());
(*computation_placers)[platform_id].creation_function = creation_function;
}
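Disabling the CHECK lets a later registration silently overwrite an earlier one for the same platform. A hedged alternative sketch that tolerates re-registering the identical creation function while still catching real conflicts (types simplified for illustration, not the actual XLA code):

```c++
#include <cassert>
#include <map>

using CreationFunction = void* (*)();

void RegisterPlacer(std::map<int, CreationFunction>& placers,
                    int platform_id, CreationFunction fn) {
  auto result = placers.emplace(platform_id, fn);
  // Re-registering the identical function is a no-op; registering a
  // different function for the same platform is still an error.
  assert(result.second || result.first->second == fn);
  (void)result;  // Silence unused warnings when asserts are compiled out.
}
```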

(Diff truncated; the remaining changed files are not shown.)
