WIP hipblaslt backporting

pemeliya committed Dec 9, 2024
1 parent f059d40 commit 87751fb
Showing 62 changed files with 4,195 additions and 2,803 deletions.
@@ -32,6 +32,7 @@ HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}'
ROCR_RUNTIME_PATH = '%{rocr_runtime_path}'
ROCR_RUNTIME_LIBRARY = '%{rocr_runtime_library}'
VERBOSE = '%{crosstool_verbose}'=='1'
ROCM_AMDGPU_TARGETS = '%{rocm_amdgpu_targets}'

def Log(s):
  print('gpus/crosstool: {0}'.format(s))
@@ -93,6 +94,27 @@ def GetHostCompilerOptions(argv):

  return opts

def GetHipccOptions(argv):
  """Collects options from argv that should be forwarded to hipcc.

  Args:
    argv: A list of strings, possibly the argv passed to main().

  Returns:
    The string that can be passed directly to hipcc.
  """

  parser = ArgumentParser()
  parser.add_argument('--offload-arch', nargs='*', action='append')
  # TODO: find a better place for this.
  parser.add_argument('-gline-tables-only', action='store_true')

  args, _ = parser.parse_known_args(argv)

  hipcc_opts = ' -gline-tables-only ' if args.gline_tables_only else ''
  if args.offload_arch:
    hipcc_opts = hipcc_opts + ' '.join(
        ['--offload-arch=' + a for a in sum(args.offload_arch, [])])

  return hipcc_opts
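
For illustration, here is how this helper behaves on hypothetical argument vectors (a sketch, assuming the wrapper's usual "from argparse import ArgumentParser" import; the argv values below are invented, not part of the commit):

# Illustrative only -- hypothetical argv values.
print(GetHipccOptions(['--offload-arch=gfx908', '--offload-arch=gfx90a', '-c']))
# -> '--offload-arch=gfx908 --offload-arch=gfx90a'
print(GetHipccOptions(['-gline-tables-only', 'foo.cc']))
# -> ' -gline-tables-only '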

def system(cmd):
"""Invokes cmd with os.system().
@@ -122,6 +144,7 @@ def InvokeHipcc(argv, log=False):
"""

host_compiler_options = GetHostCompilerOptions(argv)
hipcc_compiler_options = GetHipccOptions(argv)
opt_option = GetOptionValue(argv, 'O')
m_options = GetOptionValue(argv, 'm')
m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
@@ -160,7 +183,7 @@
  srcs = ' '.join(src_files)
  out = ' -o ' + out_file[0]

  hipccopts = ' '
  hipccopts = hipcc_compiler_options + ' '
  # In a hip-clang environment, we need to make sure that the hip header is
  # included before standard math headers like <complex> are included in any
  # source. Otherwise, we get a build error.
@@ -214,8 +237,10 @@ def main():
  parser.add_argument('-pass-exit-codes', action='store_true')
  args, leftover = parser.parse_known_args(sys.argv[1:])

  if VERBOSE: print('PWD=' + os.getcwd())
  if VERBOSE: print('HIPCC_ENV=' + HIPCC_ENV)
  if VERBOSE:
    print('PWD=' + os.getcwd())
    print('HIPCC_ENV=' + HIPCC_ENV)
    print('ROCM_AMDGPU_TARGETS=' + ROCM_AMDGPU_TARGETS)

  if args.x and args.x[0] == 'rocm':
    # compilation for GPU objects
8 changes: 8 additions & 0 deletions third_party/gpus/rocm/build_defs.bzl.tpl
@@ -56,6 +56,14 @@ def if_rocm_hipblaslt(x):
    return select({"//conditions:default": x})
  return select({"//conditions:default": []})

def rocm_hipblaslt():
  return %{rocm_is_configured} and %{rocm_hipblaslt}

def if_rocm_hipblaslt(x):
  if %{rocm_is_configured} and (%{rocm_hipblaslt} == "True"):
    return select({"//conditions:default": x})
  return select({"//conditions:default": []})

def rocm_library(copts = [], **kwargs):
  """Wrapper over cc_library which adds default ROCm options."""
  native.cc_library(copts = rocm_default_copts() + copts, **kwargs)
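
As a sketch of how these macros would typically be consumed in a BUILD file (the target, source, and define names below are hypothetical, not part of this commit):

# Hypothetical usage: gate hipBLASLt-specific sources and defines on the
# rocm_hipblaslt configuration.
cc_library(
    name = "gemm_backend",
    srcs = ["gemm_backend.cc"] + if_rocm_hipblaslt(["gemm_hipblaslt.cc"]),
    defines = if_rocm_hipblaslt(["TF_HIPBLASLT=1"]),
)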
3 changes: 3 additions & 0 deletions third_party/gpus/rocm_configure.bzl
@@ -739,6 +739,9 @@ def _create_local_rocm_repository(repository_ctx):
"%{hip_runtime_library}": "amdhip64",
"%{crosstool_verbose}": _crosstool_verbose(repository_ctx),
"%{gcc_host_compiler_path}": str(cc),
"%{rocm_amdgpu_targets}": ",".join(
["\"%s\"" % c for c in rocm_config.amdgpu_targets],
),
},
)
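
On a machine configured for, say, gfx906 and gfx90a, this substitution would expand the ROCM_AMDGPU_TARGETS line of the crosstool wrapper template to the following (illustrative values):

ROCM_AMDGPU_TARGETS = '"gfx906","gfx90a"'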

71 changes: 67 additions & 4 deletions third_party/xla/xla/debug_options_flags.cc
@@ -41,7 +41,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_llvm_enable_invariant_load_metadata(true);
opts.set_xla_llvm_disable_expensive_passes(false);
opts.set_xla_backend_optimization_level(3);
opts.set_xla_gpu_autotune_level(4);
opts.set_xla_gpu_autotune_level(0);
opts.set_xla_cpu_multi_thread_eigen(true);
opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
opts.set_xla_gpu_asm_extra_flags("");
@@ -91,7 +91,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

// Note: CublasLt will be used for FP8 GEMMs regardless of the value of this
// flag.
opts.set_xla_gpu_enable_cublaslt(false);
opts.set_xla_gpu_enable_cublaslt(true);

// TODO(b/258036887): Enable gpu_graph_level=3.
opts.set_xla_gpu_graph_level(2);
@@ -171,7 +171,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_partitioning_algorithm(
DebugOptions::PARTITIONING_ALGORITHM_NOOP);

opts.set_xla_gpu_enable_triton_gemm(true);
opts.set_xla_gpu_enable_triton_gemm(false);
opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
opts.set_xla_gpu_triton_gemm_any(false);
opts.set_xla_gpu_enable_triton_softmax_fusion(false);
@@ -202,6 +202,19 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_ensure_minor_dot_contraction_dims(false);
opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(true);

opts.set_xla_gpu_autotune_gemm_rtol(0.1f);

opts.set_xla_gpu_redzone_padding_bytes(8 * 1024 * 1024);

// Minimum combined size of matrices in a matrix multiplication for it to be
// rewritten to a cuBLAS or Triton kernel call.
// This threshold is a conservative estimate and has been measured
// to be always beneficial (up to generally several times faster)
// on V100 and H100 GPUs. See openxla/xla #9319 for details.
const int64_t kDefaultMinGemmRewriteSize = 100;
opts.set_xla_gpu_gemm_rewrite_size_threshold(kDefaultMinGemmRewriteSize);
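// Illustrative arithmetic (not from the commit): a dot of shapes
// [m=10, k=20] x [k=20, n=30] has a combined non-batch operand size of
// 10*20 + 20*30 = 800 >= 100 elements, so it remains eligible for the
// cuBLAS/Triton rewrite under this default.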
return opts;
}

@@ -287,6 +300,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
};
};

auto float_setter_for =
[debug_options](void (DebugOptions::*member_setter)(float)) {
return [debug_options, member_setter](float value) {
(debug_options->*member_setter)(value);
return true;
};
};

// Custom "sub-parser" lambda for xla_gpu_shape_checks.
auto setter_for_xla_gpu_shape_checks =
[debug_options](const std::string& value) {
@@ -600,7 +621,35 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
debug_options->xla_gpu_autotune_level(),
"Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = "
"on+init; 3 = on+init+reinit; 4 = on+init+reinit+check."));
"on+init; 3 = on+init+reinit; 4 = on+init+reinit+check; "
"5 = on+init+reinit+check and skip WRONG_RESULT solutions. See also "
"the related flag xla_gpu_autotune_gemm_rtol. Remark that, setting the "
"level to 5 only makes sense if you are sure that the reference (first "
"in the list) solution is numerically CORRECT. Otherwise, the autotuner "
"might discard many other correct solutions based on the failed "
"BufferComparator test."));
flag_list->push_back(tsl::Flag(
"xla_gpu_dump_autotune_results_to",
string_setter_for(&DebugOptions::set_xla_gpu_dump_autotune_results_to),
debug_options->xla_gpu_dump_autotune_results_to(),
"File to write autotune results to. It will be a binary file unless the "
"name ends with .txt or .textproto. Warning: The results are written at "
"every compilation, possibly multiple times per process. This only works "
"on CUDA. In tests, the TEST_UNDECLARED_OUTPUTS_DIR prefix can be used "
"to write to their output directory."));
flag_list->push_back(tsl::Flag(
"xla_gpu_load_autotune_results_from",
string_setter_for(&DebugOptions::set_xla_gpu_load_autotune_results_from),
debug_options->xla_gpu_load_autotune_results_from(),
"File to load autotune results from. It will be considered a binary file "
"unless the name ends with .txt or .textproto. It will be loaded at most "
"once per process. This only works on CUDA. In tests, the TEST_WORKSPACE "
"prefix can be used to load files from their data dependencies."));
flag_list->push_back(tsl::Flag(
"xla_gpu_autotune_gemm_rtol",
float_setter_for(&DebugOptions::set_xla_gpu_autotune_gemm_rtol),
debug_options->xla_gpu_autotune_gemm_rtol(),
"Relative precision for comparing GEMM solutions vs the reference one"));
flag_list->push_back(tsl::Flag(
"xla_force_host_platform_device_count",
int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count),
@@ -1325,6 +1374,20 @@
int64_setter_for(&DebugOptions::set_xla_debug_buffer_assignment_show_max),
debug_options->xla_debug_buffer_assignment_show_max(),
"Number of buffers to display when debugging the buffer assignment"));
flag_list->push_back(tsl::Flag(
"xla_gpu_gemm_rewrite_size_threshold",
int64_setter_for(&DebugOptions::set_xla_gpu_gemm_rewrite_size_threshold),
debug_options->xla_gpu_gemm_rewrite_size_threshold(),
"Threshold until which elemental dot emitter is preferred for GEMMs "
"(minumum combined number of elements of both matrices "
"in non-batch dimensions to be considered for a rewrite)."));
flag_list->push_back(tsl::Flag(
"xla_gpu_redzone_padding_bytes",
int64_setter_for(&DebugOptions::set_xla_gpu_redzone_padding_bytes),
debug_options->xla_gpu_redzone_padding_bytes(),
"Amount of padding the redzone allocator will put on one side of each "
"buffer it allocates. (So the buffer's total size will be increased by "
"2x this value.)"));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
1 change: 1 addition & 0 deletions third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td
@@ -183,6 +183,7 @@ def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandS
Arg<LHLO_Buffer, "", [MemWrite]>:$d,
Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$bias,
Arg<Optional<LHLO_Buffer>, "", [MemRead, MemWrite]>:$aux,
Arg<Optional<LHLO_Buffer>, "", [MemRead, MemWrite]>:$workspace,
MHLO_DotDimensionNumbers:$dot_dimension_numbers,
MHLO_PrecisionConfigAttr:$precision_config,
F64Attr:$alpha_real,
132 changes: 52 additions & 80 deletions third_party/xla/xla/service/gpu/BUILD
@@ -941,6 +941,7 @@ cc_library(
":custom_call_thunk",
":fft_thunk",
":gemm_thunk",
":gpublas_lt_matmul_thunk",
":gpu_asm_opts_util",
":gpu_constants",
":gpu_conv_runner",
@@ -1359,70 +1360,50 @@
)

cc_library(
name = "cublas_lt_matmul_thunk",
srcs = if_cuda_is_configured(["cublas_lt_matmul_thunk.cc"]) + if_rocm_is_configured([
"cublas_lt_matmul_thunk.cc",
]),
hdrs = if_cuda_is_configured(["cublas_lt_matmul_thunk.h"]) + if_rocm_is_configured([
"cublas_lt_matmul_thunk.h",
]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM=1",
]),
visibility = ["//visibility:public"],
deps = if_gpu_is_configured([
name = "gpublas_lt_matmul_thunk",
srcs = ["gpublas_lt_matmul_thunk.cc"],
hdrs = ["gpublas_lt_matmul_thunk.h"],
deps = [
":matmul_utils",
":thunk",
"//xla/service:buffer_assignment",
"//xla:status",
"//xla/stream_executor:device_memory",
"//xla/stream_executor",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
]) + if_cuda_is_configured([
"//xla/stream_executor/cuda:cublas_lt_header",
"//xla/stream_executor/cuda:cublas_plugin",
]) + if_rocm_is_configured([
"//xla/stream_executor/rocm:hipblas_lt_header",
]),
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/stream_executor:device_memory",
"//tensorflow/compiler/xla/stream_executor:stream_executor_headers",
"//tensorflow/tsl/platform:logging",
],
)

cc_library(
name = "gemm_algorithm_picker",
srcs = if_cuda_is_configured(["gemm_algorithm_picker.cc"]),
hdrs = if_cuda_is_configured(["gemm_algorithm_picker.h"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
visibility = ["//visibility:public"],
deps = if_cuda_is_configured([
srcs = ["gemm_algorithm_picker.cc"],
hdrs = ["gemm_algorithm_picker.h"],
deps = [
":backend_configs_cc",
":buffer_comparator",
":gemm_thunk",
":gpu_asm_opts_util",
":gpu_conv_runner",
":gpu_executable",
":ir_emission_utils",
":matmul_utils",
":stream_executor_util",
":matmul_utils",
":autotuner_compile_util",
":autotuner_util",
"@com_google_absl//absl/strings",
"//xla:autotune_results_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_pass",
"//xla:status_macros",
"//xla/stream_executor",
"//xla/stream_executor:blas",
"//xla/stream_executor/cuda:cublas_lt_header",
"//xla/stream_executor/cuda:cublas_plugin",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor/gpu:redzone_allocator",
"//xla:util",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logger",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
"//xla:autotuning_proto_cc",
"@local_tsl//tsl/util/proto:proto_utils",
]),
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/util/proto:proto_utils",
"//tensorflow/tsl/platform:logger",
"//tensorflow/tsl/platform:logging",
"//tensorflow/tsl/platform:statusor",
"//tensorflow/tsl/protobuf:autotuning_proto_cc",
"//tensorflow/compiler/xla/stream_executor:blas",
"//tensorflow/compiler/xla/stream_executor:device_memory",
"//tensorflow/compiler/xla/stream_executor:device_memory_allocator",
"//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator",
"@com_google_absl//absl/types:optional",
],
)

cc_library(
@@ -1515,41 +1496,32 @@ cc_library(
name = "matmul_utils",
srcs = ["matmul_utils.cc"],
hdrs = ["matmul_utils.h"],
compatible_with = get_compatible_with_portable(),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM=1",
]),
visibility = ["//visibility:public"],
compatible_with = get_compatible_with_cloud(),
defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
":backend_configs_cc",
":ir_emission_utils",
"//xla:shape_util",
"//xla:status_macros",
"//xla:statusor",
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/mlir_hlo",
"//xla/mlir_hlo:lhlo_gpu",
"//xla/stream_executor",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/mlir_hlo",
"//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu",
"//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt",
"//tensorflow/compiler/xla/stream_executor:stream_executor_headers",
"//tensorflow/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
"@com_google_absl//absl/types:any",
] + if_cuda_is_configured([
"//xla/stream_executor/cuda:cublas_lt_header",
"//xla/stream_executor/cuda:cublas_plugin",
"@local_tsl//tsl/platform:tensor_float_32_hdr_lib",
"//xla/stream_executor:host_or_device_scalar",
]) + if_rocm_is_configured([
"//xla/stream_executor/rocm:hipblas_lt_header",
"//xla/stream_executor/rocm:amdhipblaslt_plugin",
"//xla/stream_executor:host_or_device_scalar",
"//xla/stream_executor/platform:dso_loader",
"//tensorflow/tsl/platform:tensor_float_32_hdr_lib",
"//tensorflow/compiler/xla/stream_executor:host_or_device_scalar",
"//tensorflow/compiler/xla/stream_executor:scratch_allocator",
]) + if_static([
"@local_tsl//tsl/platform:tensor_float_32_utils",
"//tensorflow/tsl/platform:tensor_float_32_utils",
]),
)
