From 87751fba66f7c7f2d8ae89c37b5c39211235e1fc Mon Sep 17 00:00:00 2001 From: Pavel Emeliyanenko Date: Mon, 9 Dec 2024 05:53:57 -0600 Subject: [PATCH] WIP hipblaslt backporting --- .../bin/crosstool_wrapper_driver_rocm.tpl | 31 +- third_party/gpus/rocm/build_defs.bzl.tpl | 8 + third_party/gpus/rocm_configure.bzl | 3 + third_party/xla/xla/debug_options_flags.cc | 71 +- .../xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td | 1 + third_party/xla/xla/service/gpu/BUILD | 132 +- .../xla/xla/service/gpu/amdgpu_compiler.cc | 4 +- .../xla/xla/service/gpu/buffer_comparator.cc | 1160 ++--------------- .../xla/service/gpu/buffer_comparator.cu.cc | 181 +++ .../xla/xla/service/gpu/buffer_comparator.h | 48 +- .../xla/xla/service/gpu/cublas_cudnn.cc | 6 - .../xla/xla/service/gpu/cublas_cudnn.h | 6 - .../xla/service/gpu/cublas_lt_matmul_thunk.h | 89 -- .../xla/service/gpu/gemm_algorithm_picker.cc | 668 +++++----- .../xla/service/gpu/gemm_algorithm_picker.h | 72 +- .../xla/xla/service/gpu/gemm_rewriter.cc | 177 ++- .../xla/xla/service/gpu/gemm_rewriter.h | 7 +- third_party/xla/xla/service/gpu/gemm_thunk.cc | 4 +- .../xla/xla/service/gpu/gpu_compiler.cc | 4 + .../service/gpu/gpublas_lt_matmul_thunk.cc | 184 +++ .../xla/service/gpu/gpublas_lt_matmul_thunk.h | 73 ++ .../xla/xla/service/gpu/ir_emission_utils.cc | 43 +- .../xla/xla/service/gpu/ir_emission_utils.h | 1 + .../xla/service/gpu/ir_emitter_unnested.cc | 15 +- .../xla/xla/service/gpu/matmul_utils.cc | 1050 ++++++--------- .../xla/xla/service/gpu/matmul_utils.h | 300 ++--- .../service/gpu/runtime/cublas_lt_matmul.cc | 2 +- .../xla/xla/service/gpu/runtime/gemm.cc | 10 +- .../xla/xla/service/gpu/runtime/support.h | 2 +- .../xla/service/gpu/stream_executor_util.cc | 148 ++- .../xla/service/gpu/stream_executor_util.h | 3 +- third_party/xla/xla/service/gpu/tests/BUILD | 16 + .../service/gpu/tests/gpu_hlo_runner_test.cc | 129 ++ third_party/xla/xla/stream_executor/blas.cc | 21 + third_party/xla/xla/stream_executor/blas.h | 36 + third_party/xla/xla/stream_executor/gpu/BUILD | 40 +- .../xla/xla/stream_executor/gpu/BUILD.rej | 92 ++ .../xla/stream_executor/gpu/gpu_blas_lt.cc | 285 ++++ .../xla/xla/stream_executor/gpu/gpu_blas_lt.h | 278 ++++ .../gpu/gpu_blas_lt_gemm_runner.cc | 341 +++++ .../gpu/gpu_blas_lt_gemm_runner.h | 260 ++++ .../xla/xla/stream_executor/gpu/gpu_driver.h | 2 +- .../xla/xla/stream_executor/gpu/gpu_kernel.h | 6 +- .../stream_executor/gpu/redzone_allocator.cc | 197 +-- .../stream_executor/gpu/redzone_allocator.h | 39 +- .../gpu/redzone_allocator_kernel.h | 39 + .../gpu/redzone_allocator_kernel_cuda.cc | 146 +++ .../gpu/redzone_allocator_kernel_rocm.cu.cc | 49 + .../gpu/redzone_allocator_test.cc | 154 +++ .../xla/xla/stream_executor/kernel_spec.cc | 12 + .../xla/xla/stream_executor/kernel_spec.h | 21 + .../xla/xla/stream_executor/rocm/BUILD | 1 + .../stream_executor/rocm/hip_blas_lt.cu.cc | 58 + .../xla/xla/stream_executor/rocm/rocm_blas.cc | 14 +- .../xla/xla/stream_executor/rocm/rocm_blas.h | 4 +- .../xla/stream_executor/rocm/rocm_driver.cc | 37 +- .../rocm/rocm_driver_wrapper.h | 1 + .../stream_executor/rocm/rocm_gpu_executor.cc | 40 +- third_party/xla/xla/stream_executor/stream.cc | 111 ++ .../stream_executor/stream_executor_pimpl.h | 17 + .../mhlo_to_lhlo_with_xla.cc | 21 +- third_party/xla/xla/xla.proto | 28 +- 62 files changed, 4195 insertions(+), 2803 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/buffer_comparator.cu.cc delete mode 100644 third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.h create mode 100644 
third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.cc create mode 100644 third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.h create mode 100644 third_party/xla/xla/service/gpu/tests/gpu_hlo_runner_test.cc create mode 100644 third_party/xla/xla/stream_executor/gpu/BUILD.rej create mode 100644 third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc create mode 100644 third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h create mode 100644 third_party/xla/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc create mode 100644 third_party/xla/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h create mode 100644 third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel.h create mode 100644 third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc create mode 100644 third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc create mode 100644 third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc create mode 100644 third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cu.cc diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 77ec948af32c6e..4b3c2b289300ca 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -32,6 +32,7 @@ HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}' ROCR_RUNTIME_PATH = '%{rocr_runtime_path}' ROCR_RUNTIME_LIBRARY = '%{rocr_runtime_library}' VERBOSE = '%{crosstool_verbose}'=='1' +ROCM_AMDGPU_TARGETS = '%{rocm_amdgpu_targets}' def Log(s): print('gpus/crosstool: {0}'.format(s)) @@ -93,6 +94,27 @@ def GetHostCompilerOptions(argv): return opts +def GetHipccOptions(argv): + """Collect the -hipcc_options values from argv. + Args: + argv: A list of strings, possibly the argv passed to main(). + Returns: + The string that can be passed directly to hipcc. + """ + + parser = ArgumentParser() + parser.add_argument('--offload-arch', nargs='*', action='append') + # TODO find a better place for this + parser.add_argument('-gline-tables-only', action='store_true') + + args, _ = parser.parse_known_args(argv) + + hipcc_opts = ' -gline-tables-only ' if args.gline_tables_only else '' + if args.offload_arch: + hipcc_opts = hipcc_opts + ' '.join(['--offload-arch=' + a for a in sum(args.offload_arch, [])]) + + return hipcc_opts + def system(cmd): """Invokes cmd with os.system(). @@ -122,6 +144,7 @@ def InvokeHipcc(argv, log=False): """ host_compiler_options = GetHostCompilerOptions(argv) + hipcc_compiler_options = GetHipccOptions(argv) opt_option = GetOptionValue(argv, 'O') m_options = GetOptionValue(argv, 'm') m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']]) @@ -160,7 +183,7 @@ def InvokeHipcc(argv, log=False): srcs = ' '.join(src_files) out = ' -o ' + out_file[0] - hipccopts = ' ' + hipccopts = hipcc_compiler_options + ' ' # In hip-clang environment, we need to make sure that hip header is included # before some standard math header like is included in any source. # Otherwise, we get build error. 
@@ -214,8 +237,10 @@ def main(): parser.add_argument('-pass-exit-codes', action='store_true') args, leftover = parser.parse_known_args(sys.argv[1:]) - if VERBOSE: print('PWD=' + os.getcwd()) - if VERBOSE: print('HIPCC_ENV=' + HIPCC_ENV) + if VERBOSE: + print('PWD=' + os.getcwd()) + print('HIPCC_ENV=' + HIPCC_ENV) + print('ROCM_AMDGPU_TARGETS= ' + ROCM_AMDGPU_TARGETS) if args.x and args.x[0] == 'rocm': # compilation for GPU objects diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl index 2b4595bb222885..aae918d098272b 100644 --- a/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/gpus/rocm/build_defs.bzl.tpl @@ -56,6 +56,14 @@ def if_rocm_hipblaslt(x): return select({"//conditions:default": x}) return select({"//conditions:default": []}) +def rocm_hipblaslt(): + return %{rocm_is_configured} and %{rocm_hipblaslt} + +def if_rocm_hipblaslt(x): + if %{rocm_is_configured} and (%{rocm_hipblaslt} == "True"): + return select({"//conditions:default": x}) + return select({"//conditions:default": []}) + def rocm_library(copts = [], **kwargs): """Wrapper over cc_library which adds default ROCm options.""" native.cc_library(copts = rocm_default_copts() + copts, **kwargs) diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index ab36300911047c..cef8134239d601 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -739,6 +739,9 @@ def _create_local_rocm_repository(repository_ctx): "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), + "%{rocm_amdgpu_targets}": ",".join( + ["\"%s\"" % c for c in rocm_config.amdgpu_targets], + ), }, ) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index d64c75b728e8a2..023e8113d1bc0d 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -41,7 +41,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_llvm_enable_invariant_load_metadata(true); opts.set_xla_llvm_disable_expensive_passes(false); opts.set_xla_backend_optimization_level(3); - opts.set_xla_gpu_autotune_level(4); + opts.set_xla_gpu_autotune_level(0); opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); opts.set_xla_gpu_asm_extra_flags(""); @@ -91,7 +91,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // Note: CublasLt will be used for FP8 GEMMs regardless of the value of this // flag. - opts.set_xla_gpu_enable_cublaslt(false); + opts.set_xla_gpu_enable_cublaslt(true); // TODO(b/258036887): Enable gpu_graph_level=3. 
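The hunk above flips two defaults: GEMM/convolution autotuning now defaults to off (level 0), while cuBLASLt/hipBLASLt GEMM lowering defaults to on. Below is a minimal sketch, not part of the patch, of pinning both options from client code; it relies only on the DebugOptions setters visible in this hunk, and the include path for the header generated from xla/xla.proto is an assumption.

    #include "xla/xla.pb.h"  // generated from xla/xla.proto (assumed path)

    // Equivalent to XLA_FLAGS="--xla_gpu_autotune_level=4 --xla_gpu_enable_cublaslt=true".
    xla::DebugOptions MakeTunedDebugOptions() {
      xla::DebugOptions opts;
      opts.set_xla_gpu_autotune_level(4);      // 4 = on+init+reinit+check
      opts.set_xla_gpu_enable_cublaslt(true);  // route eligible GEMMs through hipBLASLt/cuBLASLt
      return opts;
    }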
opts.set_xla_gpu_graph_level(2); @@ -171,7 +171,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_partitioning_algorithm( DebugOptions::PARTITIONING_ALGORITHM_NOOP); - opts.set_xla_gpu_enable_triton_gemm(true); + opts.set_xla_gpu_enable_triton_gemm(false); opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); opts.set_xla_gpu_enable_triton_softmax_fusion(false); @@ -202,6 +202,19 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_ensure_minor_dot_contraction_dims(false); opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(true); + opts.set_xla_gpu_autotune_gemm_rtol(0.1f); + + opts.set_xla_gpu_redzone_padding_bytes(8 * 1024 * 1024); + + // Minimum combined size of matrices in matrix multiplication to + // be rewritten to cuBLAS or Triton kernel call. + // This threshold is a conservative estimate and has been measured + // to be always beneficial (up to generally several times faster) + // on V100 and H100 GPUs. See openxla/xla #9319 for details. + const int64_t kDefaultMinGemmRewriteSize = 100; + opts.set_xla_gpu_gemm_rewrite_size_threshold(kDefaultMinGemmRewriteSize); + + return opts; } @@ -287,6 +300,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, }; }; + auto float_setter_for = + [debug_options](void (DebugOptions::*member_setter)(float)) { + return [debug_options, member_setter](float value) { + (debug_options->*member_setter)(value); + return true; + }; + }; + // Custom "sub-parser" lambda for xla_gpu_shape_checks. auto setter_for_xla_gpu_shape_checks = [debug_options](const std::string& value) { @@ -600,7 +621,35 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level), debug_options->xla_gpu_autotune_level(), "Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = " - "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check.")); + "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check; " + "5 = on+init+reinit+check and skip WRONG_RESULT solutions. See also " + "the related flag xla_gpu_autotune_gemm_rtol. Remark that, setting the " + "level to 5 only makes sense if you are sure that the reference (first " + "in the list) solution is numerically CORRECT. Otherwise, the autotuner " + "might discard many other correct solutions based on the failed " + "BufferComparator test.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_dump_autotune_results_to", + string_setter_for(&DebugOptions::set_xla_gpu_dump_autotune_results_to), + debug_options->xla_gpu_dump_autotune_results_to(), + "File to write autotune results to. It will be a binary file unless the " + "name ends with .txt or .textproto. Warning: The results are written at " + "every compilation, possibly multiple times per process. This only works " + "on CUDA. In tests, the TEST_UNDECLARED_OUTPUTS_DIR prefix can be used " + "to write to their output directory.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_load_autotune_results_from", + string_setter_for(&DebugOptions::set_xla_gpu_load_autotune_results_from), + debug_options->xla_gpu_load_autotune_results_from(), + "File to load autotune results from. It will be considered a binary file " + "unless the name ends with .txt or .textproto. It will be loaded at most " + "once per process. This only works on CUDA. 
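The autotune flags registered above are meant to be combined: level 5 keeps the BufferComparator check but discards candidates flagged WRONG_RESULT against the reference solution, and xla_gpu_autotune_gemm_rtol controls how strict that comparison is. A minimal sketch, not part of the patch, of setting both for a process; the flag spellings are exactly the ones registered in this hunk, the 0.2 tolerance is illustrative, and the call must happen before XLA parses XLA_FLAGS.

    #include <cstdlib>

    // Autotune at level 5 with a looser 20% relative tolerance for the
    // GEMM result comparison (the default registered above is 0.1).
    void ConfigureGemmAutotuningEnv() {
      ::setenv("XLA_FLAGS",
               "--xla_gpu_autotune_level=5 "
               "--xla_gpu_autotune_gemm_rtol=0.2",
               /*overwrite=*/1);
    }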
In tests, the TEST_WORKSPACE " + "prefix can be used to load files from their data dependencies.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_autotune_gemm_rtol", + float_setter_for(&DebugOptions::set_xla_gpu_autotune_gemm_rtol), + debug_options->xla_gpu_autotune_gemm_rtol(), + "Relative precision for comparing GEMM solutions vs the reference one")); flag_list->push_back(tsl::Flag( "xla_force_host_platform_device_count", int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count), @@ -1325,6 +1374,20 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int64_setter_for(&DebugOptions::set_xla_debug_buffer_assignment_show_max), debug_options->xla_debug_buffer_assignment_show_max(), "Number of buffers to display when debugging the buffer assignment")); + flag_list->push_back(tsl::Flag( + "xla_gpu_gemm_rewrite_size_threshold", + int64_setter_for(&DebugOptions::set_xla_gpu_gemm_rewrite_size_threshold), + debug_options->xla_gpu_gemm_rewrite_size_threshold(), + "Threshold until which elemental dot emitter is preferred for GEMMs " + "(minumum combined number of elements of both matrices " + "in non-batch dimensions to be considered for a rewrite).")); + flag_list->push_back(tsl::Flag( + "xla_gpu_redzone_padding_bytes", + int64_setter_for(&DebugOptions::set_xla_gpu_redzone_padding_bytes), + debug_options->xla_gpu_redzone_padding_bytes(), + "Amount of padding the redzone allocator will put on one side of each " + "buffer it allocates. (So the buffer's total size will be increased by " + "2x this value.)")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index 44de7608229ec1..3da4c7e0dbb14e 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -183,6 +183,7 @@ def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandS Arg:$d, Arg, "", [MemRead]>:$bias, Arg, "", [MemRead, MemWrite]>:$aux, + Arg, "", [MemRead, MemWrite]>:$workspace, MHLO_DotDimensionNumbers:$dot_dimension_numbers, MHLO_PrecisionConfigAttr:$precision_config, F64Attr:$alpha_real, diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index ed6169d71fcd94..f369ade39ea83b 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -941,6 +941,7 @@ cc_library( ":custom_call_thunk", ":fft_thunk", ":gemm_thunk", + ":gpublas_lt_matmul_thunk", ":gpu_asm_opts_util", ":gpu_constants", ":gpu_conv_runner", @@ -1359,70 +1360,50 @@ cc_library( ) cc_library( - name = "cublas_lt_matmul_thunk", - srcs = if_cuda_is_configured(["cublas_lt_matmul_thunk.cc"]) + if_rocm_is_configured([ - "cublas_lt_matmul_thunk.cc", - ]), - hdrs = if_cuda_is_configured(["cublas_lt_matmul_thunk.h"]) + if_rocm_is_configured([ - "cublas_lt_matmul_thunk.h", - ]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ + name = "gpublas_lt_matmul_thunk", + srcs = ["gpublas_lt_matmul_thunk.cc"], + hdrs = ["gpublas_lt_matmul_thunk.h"], + deps = [ ":matmul_utils", ":thunk", - "//xla/service:buffer_assignment", - "//xla:status", - "//xla/stream_executor:device_memory", - "//xla/stream_executor", - "@local_tsl//tsl/platform:logging", - 
"@local_tsl//tsl/platform:statusor", - ]) + if_cuda_is_configured([ - "//xla/stream_executor/cuda:cublas_lt_header", - "//xla/stream_executor/cuda:cublas_plugin", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:hipblas_lt_header", - ]), + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/stream_executor:device_memory", + "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", + "//tensorflow/tsl/platform:logging", + ], ) cc_library( name = "gemm_algorithm_picker", - srcs = if_cuda_is_configured(["gemm_algorithm_picker.cc"]), - hdrs = if_cuda_is_configured(["gemm_algorithm_picker.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + srcs = ["gemm_algorithm_picker.cc"], + hdrs = ["gemm_algorithm_picker.h"], + deps = [ ":backend_configs_cc", ":buffer_comparator", - ":gemm_thunk", - ":gpu_asm_opts_util", ":gpu_conv_runner", + ":gpu_executable", ":ir_emission_utils", - ":matmul_utils", ":stream_executor_util", + ":matmul_utils", + ":autotuner_compile_util", ":autotuner_util", - "@com_google_absl//absl/strings", - "//xla:autotune_results_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla:status_macros", - "//xla/stream_executor", - "//xla/stream_executor:blas", - "//xla/stream_executor/cuda:cublas_lt_header", - "//xla/stream_executor/cuda:cublas_plugin", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor/gpu:redzone_allocator", - "//xla:util", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logger", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "//xla:autotuning_proto_cc", - "@local_tsl//tsl/util/proto:proto_utils", - ]), + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/util/proto:proto_utils", + "//tensorflow/tsl/platform:logger", + "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:statusor", + "//tensorflow/tsl/protobuf:autotuning_proto_cc", + "//tensorflow/compiler/xla/stream_executor:blas", + "//tensorflow/compiler/xla/stream_executor:device_memory", + "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", + "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", + "@com_google_absl//absl/types:optional", + ], ) cc_library( @@ -1515,41 +1496,32 @@ cc_library( name = "matmul_utils", srcs = ["matmul_utils.cc"], hdrs = ["matmul_utils.h"], - compatible_with = get_compatible_with_portable(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_cloud(), + defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":backend_configs_cc", ":ir_emission_utils", - "//xla:shape_util", - "//xla:status_macros", - "//xla:statusor", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo_gpu", - "//xla/stream_executor", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + 
"//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt", + "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", + "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", + "@com_google_absl//absl/types:any", ] + if_cuda_is_configured([ - "//xla/stream_executor/cuda:cublas_lt_header", - "//xla/stream_executor/cuda:cublas_plugin", - "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", - "//xla/stream_executor:host_or_device_scalar", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:hipblas_lt_header", - "//xla/stream_executor/rocm:amdhipblaslt_plugin", - "//xla/stream_executor:host_or_device_scalar", - "//xla/stream_executor/platform:dso_loader", + "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", + "//tensorflow/compiler/xla/stream_executor:host_or_device_scalar", + "//tensorflow/compiler/xla/stream_executor:scratch_allocator", ]) + if_static([ - "@local_tsl//tsl/platform:tensor_float_32_utils", + "//tensorflow/tsl/platform:tensor_float_32_utils", ]), ) diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc index 52306d05585ebe..1f75ce8c22acc6 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/gpu_conv_padding_legalization.h" #include "xla/service/gpu/gpu_conv_rewriter.h" +#include "xla/service/gpu/gemm_algorithm_picker.h" #include "xla/service/gpu/gpu_layout_assignment.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/reduction_degenerate_dim_remover.h" @@ -155,8 +156,7 @@ Status AMDGPUCompiler::AddConvAndGemmAutotuningPasses( if (GpuConvAlgorithmPicker::IsEnabled(hlo_module)) { pipeline->AddPass(autotune_config); } - // TODO: - // pipeline->AddPass(autotune_config); + pipeline->AddPass(autotune_config); return OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cc index b150f858cc136b..f126ae5ec5459f 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,1078 +13,88 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "xla/service/gpu/buffer_comparator.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include #include -#include - -#include "absl/base/call_once.h" -#include "absl/strings/str_replace.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/hlo_module_config.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/asm_compiler.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" +#include +#include +#include +#include + +#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/stream_executor/device_description.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/kernel.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace gpu { -static constexpr double kTolerance = 0.1f; - -// Comparison kernel code: compare two buffers of -// fp8/bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the -// relative error does not exceed the passed rel_error_threshold. Write the -// number of mismatches into out parameter mismatch_count. - -// NaN's are considered equal, and for half's we clamp all numbers to largest -// and smallest numbers representable to avoid miscomparisons due to overflows. -// -// The PTX below is compiled from the CUDA code below. The following command was -// used with NVCC from CUDA 11.8 -// -// nvcc --gpu-architecture=compute_50 --ptx buffer_compare.cu -// -// The CUDA code follows: - -// #include -// #include -// #include -// -// namespace { -// -// __device__ __inline__ float __xla_buffer_comparator_canonicalize(float input) -// { -// // All fp16 infinities are treated as 65505 or -65505, in order to avoid -// // differences due to overflows. -// return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f)); -// } -// -// } // end anonymous namespace -// -// extern "C" { // avoid name mangling -// -// -// __global__ void __xla_fp8_e4m3fn_comparison(__nv_fp8_storage_t *buffer_a, -// __nv_fp8_storage_t *buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int *mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) -// return; -// // TODO(philipphack): Replace with direct conversion to float when this -// // functionality becomes availabe. 
-// float elem_a = -// __half2float(__nv_cvt_fp8_to_halfraw(buffer_a[idx], __NV_E4M3)); -// float elem_b = -// __half2float(__nv_cvt_fp8_to_halfraw(buffer_b[idx], __NV_E4M3)); -// elem_a = __xla_buffer_comparator_canonicalize(elem_a); -// elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) -// return; -// -// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + -// 1); -// -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// __global__ void __xla_fp8_e5m2_comparison(__nv_fp8_storage_t *buffer_a, -// __nv_fp8_storage_t *buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int *mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) -// return; -// // TODO(philipphack): Replace with direct conversion to float when this -// // functionality becomes availabe. -// float elem_a = -// __half2float(__nv_cvt_fp8_to_halfraw(buffer_a[idx], __NV_E5M2)); -// float elem_b = -// __half2float(__nv_cvt_fp8_to_halfraw(buffer_b[idx], __NV_E5M2)); -// elem_a = __xla_buffer_comparator_canonicalize(elem_a); -// elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) -// return; -// -// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + -// 1); -// -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float elem_a = __half2float(buffer_a[idx]); -// float elem_b = __half2float(buffer_b[idx]); -// elem_a = __xla_buffer_comparator_canonicalize(elem_a); -// elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) return; -// -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float elem_a = buffer_a[idx]; -// float elem_b = buffer_b[idx]; -// if (isnan(elem_a) && isnan(elem_b)) return; -// if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) -// return; -// -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// -// double elem_a = buffer_a[idx]; -// double elem_b = buffer_b[idx]; -// if (isnan(elem_a) && isnan(elem_b)) return; -// if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) -// return; -// double rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); 
-// } -// -// __global__ void __xla_bf16_comparison(__nv_bfloat16* buffer_a, -// __nv_bfloat16* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float elem_a = __bfloat162float(buffer_a[idx]); -// float elem_b = __bfloat162float(buffer_b[idx]); -// elem_a = __xla_buffer_comparator_canonicalize(elem_a); -// elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) return; -// -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// // TODO(b/191520348): The comparison below requires exact equality. -// __global__ void __xla_int8_comparison(int8_t* buffer_a, int8_t* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float a = buffer_a[idx]; -// float b = buffer_b[idx]; -// float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// __global__ void __xla_int32_comparison(int* buffer_a, int* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float elem_a = static_cast(buffer_a[idx]); -// float elem_b = static_cast(buffer_b[idx]); -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// } // end extern declaration - -static const char* buffer_compare_ptx = R"( -// -// Generated by NVIDIA NVVM Compiler -// -// Compiler Build ID: CL-31833905 -// Cuda compilation tools, release 11.8, V11.8.89 -// Based on NVVM 7.0.1 -// - -.version 7.8 -.target sm_50 -.address_size 64 - - // .globl __xla_fp8_e4m3fn_comparison - -.visible .entry __xla_fp8_e4m3fn_comparison( - .param .u64 __xla_fp8_e4m3fn_comparison_param_0, - .param .u64 __xla_fp8_e4m3fn_comparison_param_1, - .param .f32 __xla_fp8_e4m3fn_comparison_param_2, - .param .u64 __xla_fp8_e4m3fn_comparison_param_3, - .param .u64 __xla_fp8_e4m3fn_comparison_param_4 -) -{ - .reg .pred %p<19>; - .reg .b16 %rs<71>; - .reg .f32 %f<30>; - .reg .b32 %r<6>; - .reg .b64 %rd<11>; - - - ld.param.u64 %rd2, [__xla_fp8_e4m3fn_comparison_param_0]; - ld.param.u64 %rd3, [__xla_fp8_e4m3fn_comparison_param_1]; - ld.param.f32 %f12, [__xla_fp8_e4m3fn_comparison_param_2]; - ld.param.u64 %rd5, [__xla_fp8_e4m3fn_comparison_param_3]; - ld.param.u64 %rd4, [__xla_fp8_e4m3fn_comparison_param_4]; - mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r2, %r1, %r3; - cvt.s64.s32 %rd1, %r4; - setp.ge.u64 %p1, %rd1, %rd5; - @%p1 bra $L__BB0_27; - - cvta.to.global.u64 %rd6, %rd2; - add.s64 %rd7, %rd6, %rd1; - ld.global.u8 %rs1, [%rd7]; - shl.b16 %rs38, %rs1, 8; - and.b16 %rs2, %rs38, -32768; - and.b16 %rs3, %rs38, 30720; - shr.u16 %rs39, %rs3, 1; - add.s16 %rs59, %rs39, 8192; - and.b16 %rs60, %rs38, 1792; - shr.u16 %rs62, %rs60, 1; - and.b16 %rs40, %rs1, 127; - setp.eq.s16 %p2, %rs40, 127; - mov.u16 %rs70, 32767; - mov.u16 %rs63, %rs70; - @%p2 bra $L__BB0_10; - - setp.eq.s16 %p3, %rs3, 0; - @%p3 
bra $L__BB0_4; - - or.b16 %rs41, %rs62, %rs2; - or.b16 %rs63, %rs41, %rs59; - bra.uni $L__BB0_10; - -$L__BB0_4: - setp.eq.s16 %p4, %rs60, 0; - mov.u16 %rs61, 0; - @%p4 bra $L__BB0_9; - - and.b16 %rs43, %rs1, 4; - setp.ne.s16 %p5, %rs43, 0; - @%p5 bra $L__BB0_8; - - mov.u16 %rs57, %rs60; - -$L__BB0_7: - shl.b16 %rs60, %rs57, 1; - add.s16 %rs59, %rs59, -1024; - and.b16 %rs44, %rs57, 512; - setp.eq.s16 %p6, %rs44, 0; - mov.u16 %rs57, %rs60; - @%p6 bra $L__BB0_7; - -$L__BB0_8: - and.b16 %rs62, %rs60, 1022; - mov.u16 %rs61, %rs59; - -$L__BB0_9: - or.b16 %rs45, %rs61, %rs2; - or.b16 %rs63, %rs45, %rs62; - -$L__BB0_10: - // begin inline asm - { cvt.f32.f16 %f27, %rs63;} - - // end inline asm - cvta.to.global.u64 %rd8, %rd3; - add.s64 %rd9, %rd8, %rd1; - ld.global.u8 %rs19, [%rd9]; - shl.b16 %rs48, %rs19, 8; - and.b16 %rs20, %rs48, -32768; - and.b16 %rs21, %rs48, 30720; - shr.u16 %rs49, %rs21, 1; - add.s16 %rs66, %rs49, 8192; - and.b16 %rs67, %rs48, 1792; - shr.u16 %rs69, %rs67, 1; - and.b16 %rs50, %rs19, 127; - setp.eq.s16 %p7, %rs50, 127; - @%p7 bra $L__BB0_19; - - setp.eq.s16 %p8, %rs21, 0; - @%p8 bra $L__BB0_13; - - or.b16 %rs51, %rs69, %rs20; - or.b16 %rs70, %rs51, %rs66; - bra.uni $L__BB0_19; - -$L__BB0_13: - setp.eq.s16 %p9, %rs67, 0; - mov.u16 %rs68, 0; - @%p9 bra $L__BB0_18; - - and.b16 %rs53, %rs19, 4; - setp.ne.s16 %p10, %rs53, 0; - @%p10 bra $L__BB0_17; - - mov.u16 %rs64, %rs67; - -$L__BB0_16: - shl.b16 %rs67, %rs64, 1; - add.s16 %rs66, %rs66, -1024; - and.b16 %rs54, %rs64, 512; - setp.eq.s16 %p11, %rs54, 0; - mov.u16 %rs64, %rs67; - @%p11 bra $L__BB0_16; - -$L__BB0_17: - and.b16 %rs69, %rs67, 1022; - mov.u16 %rs68, %rs66; - -$L__BB0_18: - or.b16 %rs55, %rs68, %rs20; - or.b16 %rs70, %rs55, %rs69; - -$L__BB0_19: - // begin inline asm - { cvt.f32.f16 %f29, %rs70;} - - // end inline asm - abs.f32 %f15, %f27; - setp.gtu.f32 %p12, %f15, 0f7F800000; - @%p12 bra $L__BB0_21; - - mov.f32 %f16, 0f477FE100; - min.f32 %f17, %f27, %f16; - mov.f32 %f18, 0fC77FE100; - max.f32 %f27, %f18, %f17; - -$L__BB0_21: - abs.f32 %f28, %f29; - setp.gtu.f32 %p13, %f28, 0f7F800000; - @%p13 bra $L__BB0_23; - - mov.f32 %f19, 0f477FE100; - min.f32 %f20, %f29, %f19; - mov.f32 %f21, 0fC77FE100; - max.f32 %f29, %f21, %f20; - abs.f32 %f28, %f29; - -$L__BB0_23: - abs.f32 %f10, %f27; - setp.gtu.f32 %p14, %f10, 0f7F800000; - setp.gtu.f32 %p15, %f28, 0f7F800000; - and.pred %p16, %p14, %p15; - @%p16 bra $L__BB0_27; - - sub.f32 %f22, %f27, %f29; - abs.f32 %f23, %f22; - max.f32 %f24, %f10, %f28; - add.f32 %f25, %f24, 0f3F800000; - div.rn.f32 %f11, %f23, %f25; - setp.gt.f32 %p17, %f11, %f12; - @%p17 bra $L__BB0_26; - - abs.f32 %f26, %f11; - setp.le.f32 %p18, %f26, 0f7F800000; - @%p18 bra $L__BB0_27; - -$L__BB0_26: - cvta.to.global.u64 %rd10, %rd4; - atom.global.add.u32 %r5, [%rd10], 1; - -$L__BB0_27: - ret; - -} - // .globl __xla_fp8_e5m2_comparison -.visible .entry __xla_fp8_e5m2_comparison( - .param .u64 __xla_fp8_e5m2_comparison_param_0, - .param .u64 __xla_fp8_e5m2_comparison_param_1, - .param .f32 __xla_fp8_e5m2_comparison_param_2, - .param .u64 __xla_fp8_e5m2_comparison_param_3, - .param .u64 __xla_fp8_e5m2_comparison_param_4 -) -{ - .reg .pred %p<11>; - .reg .b16 %rs<9>; - .reg .f32 %f<30>; - .reg .b32 %r<6>; - .reg .b64 %rd<11>; - - - ld.param.u64 %rd2, [__xla_fp8_e5m2_comparison_param_0]; - ld.param.u64 %rd3, [__xla_fp8_e5m2_comparison_param_1]; - ld.param.f32 %f12, [__xla_fp8_e5m2_comparison_param_2]; - ld.param.u64 %rd5, [__xla_fp8_e5m2_comparison_param_3]; - ld.param.u64 %rd4, [__xla_fp8_e5m2_comparison_param_4]; - 
mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r2, %r1, %r3; - cvt.s64.s32 %rd1, %r4; - setp.ge.u64 %p1, %rd1, %rd5; - @%p1 bra $L__BB1_9; - - cvta.to.global.u64 %rd6, %rd2; - add.s64 %rd7, %rd6, %rd1; - ld.global.u8 %rs3, [%rd7]; - shl.b16 %rs4, %rs3, 8; - and.b16 %rs5, %rs3, 127; - setp.gt.u16 %p2, %rs5, 124; - selp.b16 %rs1, 32767, %rs4, %p2; - // begin inline asm - { cvt.f32.f16 %f27, %rs1;} - - // end inline asm - cvta.to.global.u64 %rd8, %rd3; - add.s64 %rd9, %rd8, %rd1; - ld.global.u8 %rs6, [%rd9]; - shl.b16 %rs7, %rs6, 8; - and.b16 %rs8, %rs6, 127; - setp.gt.u16 %p3, %rs8, 124; - selp.b16 %rs2, 32767, %rs7, %p3; - // begin inline asm - { cvt.f32.f16 %f29, %rs2;} - - // end inline asm - abs.f32 %f15, %f27; - setp.gtu.f32 %p4, %f15, 0f7F800000; - @%p4 bra $L__BB1_3; - - mov.f32 %f16, 0f477FE100; - min.f32 %f17, %f27, %f16; - mov.f32 %f18, 0fC77FE100; - max.f32 %f27, %f18, %f17; - -$L__BB1_3: - abs.f32 %f28, %f29; - setp.gtu.f32 %p5, %f28, 0f7F800000; - @%p5 bra $L__BB1_5; - - mov.f32 %f19, 0f477FE100; - min.f32 %f20, %f29, %f19; - mov.f32 %f21, 0fC77FE100; - max.f32 %f29, %f21, %f20; - abs.f32 %f28, %f29; - -$L__BB1_5: - abs.f32 %f10, %f27; - setp.gtu.f32 %p6, %f10, 0f7F800000; - setp.gtu.f32 %p7, %f28, 0f7F800000; - and.pred %p8, %p6, %p7; - @%p8 bra $L__BB1_9; - - sub.f32 %f22, %f27, %f29; - abs.f32 %f23, %f22; - max.f32 %f24, %f10, %f28; - add.f32 %f25, %f24, 0f3F800000; - div.rn.f32 %f11, %f23, %f25; - setp.gt.f32 %p9, %f11, %f12; - @%p9 bra $L__BB1_8; - - abs.f32 %f26, %f11; - setp.le.f32 %p10, %f26, 0f7F800000; - @%p10 bra $L__BB1_9; - -$L__BB1_8: - cvta.to.global.u64 %rd10, %rd4; - atom.global.add.u32 %r5, [%rd10], 1; - -$L__BB1_9: - ret; - -})" - R"( - // .globl __xla_fp16_comparison -.visible .entry __xla_fp16_comparison( - .param .u64 __xla_fp16_comparison_param_0, - .param .u64 __xla_fp16_comparison_param_1, - .param .f32 __xla_fp16_comparison_param_2, - .param .u64 __xla_fp16_comparison_param_3, - .param .u64 __xla_fp16_comparison_param_4 -) -{ - .reg .pred %p<9>; - .reg .b16 %rs<3>; - .reg .f32 %f<30>; - .reg .b32 %r<6>; - .reg .b64 %rd<12>; - - - ld.param.u64 %rd2, [__xla_fp16_comparison_param_0]; - ld.param.u64 %rd3, [__xla_fp16_comparison_param_1]; - ld.param.f32 %f12, [__xla_fp16_comparison_param_2]; - ld.param.u64 %rd5, [__xla_fp16_comparison_param_3]; - ld.param.u64 %rd4, [__xla_fp16_comparison_param_4]; - mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r2, %r1, %r3; - cvt.s64.s32 %rd1, %r4; - setp.ge.u64 %p1, %rd1, %rd5; - @%p1 bra $L__BB2_9; - - cvta.to.global.u64 %rd6, %rd2; - shl.b64 %rd7, %rd1, 1; - add.s64 %rd8, %rd6, %rd7; - ld.global.u16 %rs1, [%rd8]; - // begin inline asm - { cvt.f32.f16 %f27, %rs1;} - - // end inline asm - cvta.to.global.u64 %rd9, %rd3; - add.s64 %rd10, %rd9, %rd7; - ld.global.u16 %rs2, [%rd10]; - // begin inline asm - { cvt.f32.f16 %f29, %rs2;} - - // end inline asm - abs.f32 %f15, %f27; - setp.gtu.f32 %p2, %f15, 0f7F800000; - @%p2 bra $L__BB2_3; - - mov.f32 %f16, 0f477FE100; - min.f32 %f17, %f27, %f16; - mov.f32 %f18, 0fC77FE100; - max.f32 %f27, %f18, %f17; - -$L__BB2_3: - abs.f32 %f28, %f29; - setp.gtu.f32 %p3, %f28, 0f7F800000; - @%p3 bra $L__BB2_5; - - mov.f32 %f19, 0f477FE100; - min.f32 %f20, %f29, %f19; - mov.f32 %f21, 0fC77FE100; - max.f32 %f29, %f21, %f20; - abs.f32 %f28, %f29; - -$L__BB2_5: - abs.f32 %f10, %f27; - setp.gtu.f32 %p4, %f10, 0f7F800000; - setp.gtu.f32 %p5, %f28, 0f7F800000; - and.pred %p6, %p4, %p5; - @%p6 bra $L__BB2_9; - - sub.f32 
%f22, %f27, %f29; - abs.f32 %f23, %f22; - max.f32 %f24, %f10, %f28; - add.f32 %f25, %f24, 0f3F800000; - div.rn.f32 %f11, %f23, %f25; - setp.gt.f32 %p7, %f11, %f12; - @%p7 bra $L__BB2_8; - - abs.f32 %f26, %f11; - setp.le.f32 %p8, %f26, 0f7F800000; - @%p8 bra $L__BB2_9; - -$L__BB2_8: - cvta.to.global.u64 %rd11, %rd4; - atom.global.add.u32 %r5, [%rd11], 1; - -$L__BB2_9: - ret; - -} - // .globl __xla_fp32_comparison -.visible .entry __xla_fp32_comparison( - .param .u64 __xla_fp32_comparison_param_0, - .param .u64 __xla_fp32_comparison_param_1, - .param .f32 __xla_fp32_comparison_param_2, - .param .u64 __xla_fp32_comparison_param_3, - .param .u64 __xla_fp32_comparison_param_4 -) -{ - .reg .pred %p<10>; - .reg .b16 %rs<3>; - .reg .f32 %f<15>; - .reg .b32 %r<10>; - .reg .b64 %rd<12>; - - - ld.param.u64 %rd2, [__xla_fp32_comparison_param_0]; - ld.param.u64 %rd3, [__xla_fp32_comparison_param_1]; - ld.param.f32 %f7, [__xla_fp32_comparison_param_2]; - ld.param.u64 %rd5, [__xla_fp32_comparison_param_3]; - ld.param.u64 %rd4, [__xla_fp32_comparison_param_4]; - mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r2, %r1, %r3; - cvt.s64.s32 %rd1, %r4; - setp.ge.u64 %p1, %rd1, %rd5; - @%p1 bra $L__BB3_8; - - cvta.to.global.u64 %rd6, %rd2; - shl.b64 %rd7, %rd1, 2; - add.s64 %rd8, %rd6, %rd7; - cvta.to.global.u64 %rd9, %rd3; - add.s64 %rd10, %rd9, %rd7; - ld.global.f32 %f1, [%rd8]; - abs.f32 %f2, %f1; - setp.le.f32 %p2, %f2, 0f7F800000; - ld.global.f32 %f3, [%rd10]; - abs.f32 %f14, %f3; - @%p2 bra $L__BB3_3; - - setp.gtu.f32 %p3, %f14, 0f7F800000; - @%p3 bra $L__BB3_8; - -$L__BB3_3: - setp.neu.f32 %p4, %f2, 0f7F800000; - setp.neu.f32 %p5, %f14, 0f7F800000; - or.pred %p6, %p4, %p5; - @%p6 bra $L__BB3_5; - - mov.b32 %r5, %f1; - shr.u32 %r6, %r5, 31; - cvt.u16.u32 %rs1, %r6; - mov.b32 %r7, %f3; - shr.u32 %r8, %r7, 31; - cvt.u16.u32 %rs2, %r8; - setp.eq.s16 %p7, %rs1, %rs2; - mov.f32 %f14, 0f7F800000; - @%p7 bra $L__BB3_8; - -$L__BB3_5: - sub.f32 %f9, %f1, %f3; - abs.f32 %f10, %f9; - max.f32 %f11, %f2, %f14; - add.f32 %f12, %f11, 0f3F800000; - div.rn.f32 %f6, %f10, %f12; - setp.gt.f32 %p8, %f6, %f7; - @%p8 bra $L__BB3_7; - - abs.f32 %f13, %f6; - setp.le.f32 %p9, %f13, 0f7F800000; - @%p9 bra $L__BB3_8; - -$L__BB3_7: - cvta.to.global.u64 %rd11, %rd4; - atom.global.add.u32 %r9, [%rd11], 1; - -$L__BB3_8: - ret; - -} - // .globl __xla_fp64_comparison -.visible .entry __xla_fp64_comparison( - .param .u64 __xla_fp64_comparison_param_0, - .param .u64 __xla_fp64_comparison_param_1, - .param .f32 __xla_fp64_comparison_param_2, - .param .u64 __xla_fp64_comparison_param_3, - .param .u64 __xla_fp64_comparison_param_4 -) -{ - .reg .pred %p<13>; - .reg .b16 %rs<3>; - .reg .f32 %f<2>; - .reg .b32 %r<14>; - .reg .f64 %fd<13>; - .reg .b64 %rd<12>; - - - ld.param.u64 %rd2, [__xla_fp64_comparison_param_0]; - ld.param.u64 %rd3, [__xla_fp64_comparison_param_1]; - ld.param.f32 %f1, [__xla_fp64_comparison_param_2]; - ld.param.u64 %rd5, [__xla_fp64_comparison_param_3]; - ld.param.u64 %rd4, [__xla_fp64_comparison_param_4]; - mov.u32 %r3, %ntid.x; - mov.u32 %r4, %ctaid.x; - mov.u32 %r5, %tid.x; - mad.lo.s32 %r6, %r4, %r3, %r5; - cvt.s64.s32 %rd1, %r6; - setp.ge.u64 %p1, %rd1, %rd5; - @%p1 bra $L__BB4_9; - - cvta.to.global.u64 %rd6, %rd2; - shl.b64 %rd7, %rd1, 3; - add.s64 %rd8, %rd6, %rd7; - cvta.to.global.u64 %rd9, %rd3; - add.s64 %rd10, %rd9, %rd7; - ld.global.f64 %fd1, [%rd10]; - ld.global.f64 %fd2, [%rd8]; - abs.f64 %fd3, %fd2; - setp.le.f64 %p2, %fd3, 0d7FF0000000000000; - @%p2 bra $L__BB4_3; - - 
abs.f64 %fd5, %fd1; - setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000; - @%p3 bra $L__BB4_9; - -$L__BB4_3: - { - .reg .b32 %temp; - mov.b64 {%r7, %temp}, %fd2; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r1}, %fd2; - } - and.b32 %r8, %r1, 2147483647; - setp.ne.s32 %p4, %r8, 2146435072; - setp.ne.s32 %p5, %r7, 0; - or.pred %p6, %p4, %p5; - @%p6 bra $L__BB4_6; - - { - .reg .b32 %temp; - mov.b64 {%r9, %temp}, %fd1; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r2}, %fd1; - } - and.b32 %r10, %r2, 2147483647; - setp.ne.s32 %p7, %r10, 2146435072; - setp.ne.s32 %p8, %r9, 0; - or.pred %p9, %p7, %p8; - @%p9 bra $L__BB4_6; - - shr.u32 %r11, %r1, 31; - cvt.u16.u32 %rs1, %r11; - shr.u32 %r12, %r2, 31; - cvt.u16.u32 %rs2, %r12; - setp.eq.s16 %p10, %rs1, %rs2; - @%p10 bra $L__BB4_9; - -$L__BB4_6: - sub.f64 %fd6, %fd2, %fd1; - abs.f64 %fd7, %fd6; - abs.f64 %fd8, %fd1; - max.f64 %fd9, %fd3, %fd8; - add.f64 %fd10, %fd9, 0d3FF0000000000000; - div.rn.f64 %fd4, %fd7, %fd10; - cvt.f64.f32 %fd11, %f1; - setp.gt.f64 %p11, %fd4, %fd11; - @%p11 bra $L__BB4_8; - - abs.f64 %fd12, %fd4; - setp.le.f64 %p12, %fd12, 0d7FF0000000000000; - @%p12 bra $L__BB4_9; - -$L__BB4_8: - cvta.to.global.u64 %rd11, %rd4; - atom.global.add.u32 %r13, [%rd11], 1; - -$L__BB4_9: - ret; - -} - // .globl __xla_bf16_comparison -.visible .entry __xla_bf16_comparison( - .param .u64 __xla_bf16_comparison_param_0, - .param .u64 __xla_bf16_comparison_param_1, - .param .f32 __xla_bf16_comparison_param_2, - .param .u64 __xla_bf16_comparison_param_3, - .param .u64 __xla_bf16_comparison_param_4 -) -{ - .reg .pred %p<9>; - .reg .b16 %rs<3>; - .reg .f32 %f<30>; - .reg .b32 %r<6>; - .reg .b64 %rd<12>; - - - ld.param.u64 %rd2, [__xla_bf16_comparison_param_0]; - ld.param.u64 %rd3, [__xla_bf16_comparison_param_1]; - ld.param.f32 %f12, [__xla_bf16_comparison_param_2]; - ld.param.u64 %rd5, [__xla_bf16_comparison_param_3]; - ld.param.u64 %rd4, [__xla_bf16_comparison_param_4]; - mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r2, %r1, %r3; - cvt.s64.s32 %rd1, %r4; - setp.ge.u64 %p1, %rd1, %rd5; - @%p1 bra $L__BB5_9; - - cvta.to.global.u64 %rd6, %rd2; - shl.b64 %rd7, %rd1, 1; - add.s64 %rd8, %rd6, %rd7; - ld.global.u16 %rs1, [%rd8]; - // begin inline asm - { mov.b32 %f27, {0,%rs1};} - - // end inline asm - cvta.to.global.u64 %rd9, %rd3; - add.s64 %rd10, %rd9, %rd7; - ld.global.u16 %rs2, [%rd10]; - // begin inline asm - { mov.b32 %f29, {0,%rs2};} - - // end inline asm - abs.f32 %f15, %f27; - setp.gtu.f32 %p2, %f15, 0f7F800000; - @%p2 bra $L__BB5_3; - - mov.f32 %f16, 0f477FE100; - min.f32 %f17, %f27, %f16; - mov.f32 %f18, 0fC77FE100; - max.f32 %f27, %f18, %f17; - -$L__BB5_3: - abs.f32 %f28, %f29; - setp.gtu.f32 %p3, %f28, 0f7F800000; - @%p3 bra $L__BB5_5; - - mov.f32 %f19, 0f477FE100; - min.f32 %f20, %f29, %f19; - mov.f32 %f21, 0fC77FE100; - max.f32 %f29, %f21, %f20; - abs.f32 %f28, %f29; - -$L__BB5_5: - abs.f32 %f10, %f27; - setp.gtu.f32 %p4, %f10, 0f7F800000; - setp.gtu.f32 %p5, %f28, 0f7F800000; - and.pred %p6, %p4, %p5; - @%p6 bra $L__BB5_9; - - sub.f32 %f22, %f27, %f29; - abs.f32 %f23, %f22; - max.f32 %f24, %f10, %f28; - add.f32 %f25, %f24, 0f3F800000; - div.rn.f32 %f11, %f23, %f25; - setp.gt.f32 %p7, %f11, %f12; - @%p7 bra $L__BB5_8; - - abs.f32 %f26, %f11; - setp.le.f32 %p8, %f26, 0f7F800000; - @%p8 bra $L__BB5_9; - -$L__BB5_8: - cvta.to.global.u64 %rd11, %rd4; - atom.global.add.u32 %r5, [%rd11], 1; - -$L__BB5_9: - ret; - -} - // .globl __xla_int8_comparison -.visible .entry __xla_int8_comparison( - .param .u64 
__xla_int8_comparison_param_0, - .param .u64 __xla_int8_comparison_param_1, - .param .f32 __xla_int8_comparison_param_2, - .param .u64 __xla_int8_comparison_param_3, - .param .u64 __xla_int8_comparison_param_4 -) -{ - .reg .pred %p<4>; - .reg .b16 %rs<3>; - .reg .f32 %f<12>; - .reg .b32 %r<6>; - .reg .b64 %rd<11>; - - - ld.param.u64 %rd2, [__xla_int8_comparison_param_0]; - ld.param.u64 %rd3, [__xla_int8_comparison_param_1]; - ld.param.f32 %f2, [__xla_int8_comparison_param_2]; - ld.param.u64 %rd5, [__xla_int8_comparison_param_3]; - ld.param.u64 %rd4, [__xla_int8_comparison_param_4]; - mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r2, %r1, %r3; - cvt.s64.s32 %rd1, %r4; - setp.ge.u64 %p1, %rd1, %rd5; - @%p1 bra $L__BB6_4; - - cvta.to.global.u64 %rd6, %rd2; - add.s64 %rd7, %rd6, %rd1; - ld.global.s8 %rs1, [%rd7]; - cvt.rn.f32.s16 %f3, %rs1; - cvta.to.global.u64 %rd8, %rd3; - add.s64 %rd9, %rd8, %rd1; - ld.global.s8 %rs2, [%rd9]; - cvt.rn.f32.s16 %f4, %rs2; - sub.f32 %f5, %f3, %f4; - abs.f32 %f6, %f5; - abs.f32 %f7, %f3; - abs.f32 %f8, %f4; - max.f32 %f9, %f7, %f8; - add.f32 %f10, %f9, 0f3F800000; - div.rn.f32 %f1, %f6, %f10; - setp.gt.f32 %p2, %f1, %f2; - @%p2 bra $L__BB6_3; - - abs.f32 %f11, %f1; - setp.le.f32 %p3, %f11, 0f7F800000; - @%p3 bra $L__BB6_4; - -$L__BB6_3: - cvta.to.global.u64 %rd10, %rd4; - atom.global.add.u32 %r5, [%rd10], 1; - -$L__BB6_4: - ret; - -} - // .globl __xla_int32_comparison -.visible .entry __xla_int32_comparison( - .param .u64 __xla_int32_comparison_param_0, - .param .u64 __xla_int32_comparison_param_1, - .param .f32 __xla_int32_comparison_param_2, - .param .u64 __xla_int32_comparison_param_3, - .param .u64 __xla_int32_comparison_param_4 -) -{ - .reg .pred %p<4>; - .reg .f32 %f<12>; - .reg .b32 %r<8>; - .reg .b64 %rd<12>; - - - ld.param.u64 %rd2, [__xla_int32_comparison_param_0]; - ld.param.u64 %rd3, [__xla_int32_comparison_param_1]; - ld.param.f32 %f2, [__xla_int32_comparison_param_2]; - ld.param.u64 %rd5, [__xla_int32_comparison_param_3]; - ld.param.u64 %rd4, [__xla_int32_comparison_param_4]; - mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r2, %r1, %r3; - cvt.s64.s32 %rd1, %r4; - setp.ge.u64 %p1, %rd1, %rd5; - @%p1 bra $L__BB7_4; - - cvta.to.global.u64 %rd6, %rd2; - shl.b64 %rd7, %rd1, 2; - add.s64 %rd8, %rd6, %rd7; - ld.global.u32 %r5, [%rd8]; - cvt.rn.f32.s32 %f3, %r5; - cvta.to.global.u64 %rd9, %rd3; - add.s64 %rd10, %rd9, %rd7; - ld.global.u32 %r6, [%rd10]; - cvt.rn.f32.s32 %f4, %r6; - sub.f32 %f5, %f3, %f4; - abs.f32 %f6, %f5; - abs.f32 %f7, %f3; - abs.f32 %f8, %f4; - max.f32 %f9, %f7, %f8; - add.f32 %f10, %f9, 0f3F800000; - div.rn.f32 %f1, %f6, %f10; - setp.gt.f32 %p2, %f1, %f2; - @%p2 bra $L__BB7_3; - - abs.f32 %f11, %f1; - setp.le.f32 %p3, %f11, 0f7F800000; - @%p3 bra $L__BB7_4; - -$L__BB7_3: - cvta.to.global.u64 %rd11, %rd4; - atom.global.add.u32 %r7, [%rd11], 1; - -$L__BB7_4: - ret; - -} -)"; - template using ComparisonKernelT = se::TypedKernel, se::DeviceMemory, float, uint64_t, se::DeviceMemory>; +struct ComparisonParams { + double relative_tol = 0.1; + bool verbose = true; + const Shape *shape = nullptr; + se::Stream* stream = nullptr; + se::DeviceMemoryBase current{}; + se::DeviceMemoryBase expected{}; +}; + // Compares two buffers on the GPU. // // Returns `true` if two buffers are equal, `false` otherwise. 
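ComparisonParams above folds the stream, the two buffers, the shape, and the tolerance into one struct, and the tolerance is now supplied by the caller instead of the hard-coded kTolerance this patch removes. A minimal sketch, not part of the patch, of how an autotuner-style caller is expected to drive the reworked comparator; it assumes the BufferComparator(shape, tolerance, verbose) constructor and the CompareEqual(stream, current, expected) signature shown later in this file, with the tolerance presumably fed from xla_gpu_autotune_gemm_rtol.

    #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"

    namespace xla::gpu {
    // Returns true when `candidate` matches `reference` within `gemm_rtol`.
    StatusOr<bool> CandidateMatchesReference(se::Stream* stream,
                                             const Shape& output_shape,
                                             double gemm_rtol,
                                             se::DeviceMemoryBase candidate,
                                             se::DeviceMemoryBase reference) {
      BufferComparator comparator(output_shape, gemm_rtol, /*verbose=*/true);
      return comparator.CompareEqual(stream, candidate, reference);
    }
    }  // namespace xla::gpu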
template -static StatusOr DeviceCompare(se::Stream* stream, - se::DeviceMemoryBase current, - se::DeviceMemoryBase expected, - const Shape& buffer_shape, - const HloModuleConfig& config, - absl::string_view kernel_name) { - se::StreamExecutor* executor = stream->parent(); +static StatusOr DeviceCompare( + absl::string_view kernel_name, void* kernel_symbol, + const ComparisonParams& params) { + se::StreamExecutor* executor = params.stream->parent(); se::ScopedDeviceMemory out_param = executor->AllocateOwnedScalar(); - stream->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); - if (current.size() != expected.size()) { - return InternalError("Mismatched buffer size: %d bytes vs. %d bytes", - current.size(), expected.size()); + params.stream->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); + if (params.current.size() != params.expected.size()) { + return Internal("Mismatched buffer size: %d bytes vs. %d bytes", + params.current.size(), params.expected.size()); } - se::DeviceMemory current_typed(current); - se::DeviceMemory expected_typed(expected); + se::DeviceMemory current_typed(params.current); + se::DeviceMemory expected_typed(params.expected); uint64_t buffer_size = current_typed.ElementCount(); - absl::Span compiled_ptx = {}; - StatusOr> compiled_ptx_or = - se::CompileGpuAsmOrGetCached( - executor->device_ordinal(), buffer_compare_ptx, - PtxOptsFromDebugOptions(config.debug_options())); - if (compiled_ptx_or.ok()) { - compiled_ptx = std::move(compiled_ptx_or).value(); - } else { - static absl::once_flag ptxas_not_found_logged; - absl::call_once(ptxas_not_found_logged, [&]() { - LOG(WARNING) - << compiled_ptx_or.status() - << "\nRelying on driver to perform ptx compilation. " - << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda " - << " or modifying $PATH can be used to set the location of ptxas" - << "\nThis message will only be logged once."; - }); - } - TF_ASSIGN_OR_RETURN( std::unique_ptr> comparison_kernel, (executor->CreateTypedKernel, - se::DeviceMemory, float, uint64_t, - se::DeviceMemory>( - kernel_name, buffer_compare_ptx, compiled_ptx))); + se::DeviceMemory, float, uint64_t, + se::DeviceMemory>( + kernel_name, kernel_symbol))); - const se::DeviceDescription& gpu_device_info = - executor->GetDeviceDescription(); + auto gpu_device_info = GetGpuDeviceInfo(executor); - TF_ASSIGN_OR_RETURN(LaunchDimensions dim, - CalculateLaunchDimensions(buffer_shape, gpu_device_info)); + TF_ASSIGN_OR_RETURN(auto dim, + CalculateLaunchDimensions(*params.shape, gpu_device_info)); - LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block(); - LaunchDimensions::Dim3D block_counts = dim.block_counts(); - TF_RETURN_IF_ERROR(stream->ThenLaunch( - se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), - se::BlockDim(block_counts.x, block_counts.y, block_counts.z), - *comparison_kernel, current_typed, expected_typed, - static_cast(kTolerance), buffer_size, out_param.cref())); + auto threads = dim.thread_counts_per_block(), + blocks = dim.block_counts(); + params.stream->ThenLaunch(se::ThreadDim(threads.x, threads.y, threads.z), + se::BlockDim(blocks.x, blocks.y, blocks.z), *comparison_kernel, + current_typed, expected_typed, static_cast(params.relative_tol), + buffer_size, out_param.cref()); uint64_t result = -1; CHECK_EQ(out_param->size(), sizeof(result)); - stream->ThenMemcpy(&result, *out_param, sizeof(result)); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + params.stream->ThenMemcpy(&result, *out_param, sizeof(result)); + 
TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); return result == 0; } @@ -1093,13 +103,13 @@ static StatusOr DeviceCompare(se::Stream* stream, // // Returns true if no differences were seen, false otherwise. template -StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase current, - se::DeviceMemoryBase expected) { - int64_t n = current.size() / sizeof(ElementType); +static StatusOr HostCompare(const ComparisonParams& params) { + int64_t n = params.current.size() / sizeof(ElementType); std::vector host_current(n), host_expected(n); - stream->ThenMemcpy(host_current.data(), current, current.size()); - stream->ThenMemcpy(host_expected.data(), expected, expected.size()); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + + params.stream->ThenMemcpy(host_current.data(), params.current, params.current.size()); + params.stream->ThenMemcpy(host_expected.data(), params.expected, params.expected.size()); + TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); const auto canonicalize = [](ComparisonType a) -> ComparisonType { if (std::is_same::value && a) { @@ -1112,6 +122,7 @@ StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase current, return a; }; int differences_seen = 0; + for (int64_t i = 0; i < n && differences_seen < 10; ++i) { auto current_value = static_cast(host_current[i]); auto expected_value = static_cast(host_expected[i]); @@ -1131,78 +142,69 @@ StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase current, !(std::abs(current_value_canonical - expected_value_canonical) / (std::max(std::abs(current_value_canonical), std::abs(expected_value_canonical)) + - 1) < - kTolerance)) { + 1) < params.relative_tol)) { + if(!params.verbose) return false; // Return immediately if not verbose. ++differences_seen; LOG(ERROR) << "Difference at " << i << ": " << current_value - << ", expected " << expected_value; + << ", expected " << expected_value; } } return differences_seen == 0; } template -static StatusOr CompareEqualParameterized(se::Stream* stream, - se::DeviceMemoryBase current, - se::DeviceMemoryBase expected, - const Shape& shape, - const HloModuleConfig& config, - absl::string_view kernel_name) { +static StatusOr CompareEqualParameterized( + absl::string_view kernel_name, void* kernel_symbol, + const ComparisonParams& params) { XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual"); - TF_ASSIGN_OR_RETURN(bool result, - DeviceCompare(stream, current, expected, shape, - config, kernel_name)); + TF_ASSIGN_OR_RETURN( + bool result, DeviceCompare(kernel_name, kernel_symbol, params)); if (result) { return true; } - TF_ASSIGN_OR_RETURN(bool host_return, (HostCompare( - stream, current, expected))); + TF_ASSIGN_OR_RETURN(bool host_return, + (HostCompare(params))); CHECK_EQ(host_return, result) << "Host comparison succeeded even though GPU comparison failed."; - return false; } StatusOr BufferComparator::CompareEqual( se::Stream* stream, se::DeviceMemoryBase current, se::DeviceMemoryBase expected) const { + + ComparisonParams params{ + relative_tol_, verbose_, &shape_, stream, current, expected}; + switch (shape_.element_type()) { - case xla::F8E4M3FN: - return CompareEqualParameterized( - stream, current, expected, shape_, config_, - "__xla_fp8_e4m3fn_comparison"); - case xla::F8E5M2: - return CompareEqualParameterized( - stream, current, expected, shape_, config_, - "__xla_fp8_e5m2_comparison"); case xla::F16: return CompareEqualParameterized( - stream, current, expected, shape_, config_, "__xla_fp16_comparison"); + "fp16_comparison", 
buffer_comparator::fp16_comparison(), params); case xla::BF16: - return CompareEqualParameterized( - stream, current, expected, shape_, config_, "__xla_bf16_comparison"); + return CompareEqualParameterized( + "bf16_comparison", buffer_comparator::bf16_comparison(), params); case xla::F32: return CompareEqualParameterized( - stream, current, expected, shape_, config_, "__xla_fp32_comparison"); + "fp32_comparison", buffer_comparator::fp32_comparison(), params); case xla::F64: return CompareEqualParameterized( - stream, current, expected, shape_, config_, "__xla_fp64_comparison"); + "fp64_comparison", buffer_comparator::fp64_comparison(), params); case xla::S8: return CompareEqualParameterized( - stream, current, expected, shape_, config_, "__xla_int8_comparison"); + "int8_comparison", buffer_comparator::int8_comparison(), params); case xla::S32: return CompareEqualParameterized( - stream, current, expected, shape_, config_, "__xla_int32_comparison"); + "int32_comparison", buffer_comparator::int32_comparison(), params); default: return Unimplemented("Unimplemented element type"); } } -BufferComparator::BufferComparator(const Shape& shape, - const HloModuleConfig& config) - : shape_(shape), config_(config) { +BufferComparator::BufferComparator(const Shape& shape, double tolerance, + bool verbose) : + shape_(shape), relative_tol_(tolerance), verbose_(verbose) { // Normalize complex shapes: since we treat the passed array as a contiguous // storage it does not matter which dimension are we doubling. auto double_dim_size = [&]() { diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc new file mode 100644 index 00000000000000..3345527749810f --- /dev/null +++ b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc @@ -0,0 +1,181 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#include +#include +#include + +using bfloat16 = __nv_bfloat16; +#define BF16_TO_F32 __bfloat162float + +#elif TENSORFLOW_USE_ROCM +#include +#include + +#include "rocm/rocm_config.h" + +using bfloat16 = hip_bfloat16; +#define BF16_TO_F32 float + +#endif + +#include + +namespace xla { +namespace gpu { +namespace buffer_comparator { + +// Comparison kernel code: compare two buffers of +// fp8/bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the +// relative error does not exceed the passed rel_error_threshold. Write the +// number of mismatches into out parameter mismatch_count. + +// NaN's are considered equal, and for half's we clamp all numbers to largest +// and smallest numbers representable to avoid miscomparisons due to overflows. +namespace { + +__device__ __inline__ float Canonicalize(float input) { + // All fp16 infinities are treated as 65505 or -65505, in order to avoid + // differences due to overflows. + return isnan(input) ? 
input : max(-65505.0f, min(input, 65505.0f)); +} + +__global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float elem_a = __half2float(buffer_a[idx]); + float elem_b = __half2float(buffer_b[idx]); + elem_a = Canonicalize(elem_a); + elem_b = Canonicalize(elem_b); + if (isnan(elem_a) && isnan(elem_b)) return; + + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +__global__ void xla_fp32_comparison(float* buffer_a, float* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float elem_a = buffer_a[idx]; + float elem_b = buffer_b[idx]; + if (isnan(elem_a) && isnan(elem_b)) return; + if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) + return; + + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +__global__ void xla_fp64_comparison(double* buffer_a, double* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + + double elem_a = buffer_a[idx]; + double elem_b = buffer_b[idx]; + if (isnan(elem_a) && isnan(elem_b)) return; + if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) + return; + double rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +__global__ void xla_bf16_comparison(bfloat16* buffer_a, bfloat16* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float elem_a = BF16_TO_F32(buffer_a[idx]); + float elem_b = BF16_TO_F32(buffer_b[idx]); + elem_a = Canonicalize(elem_a); + elem_b = Canonicalize(elem_b); + if (isnan(elem_a) && isnan(elem_b)) return; + + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +// TODO(b/191520348): The comparison below requires exact equality. 
+__global__ void xla_int8_comparison(int8_t* buffer_a, int8_t* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float a = buffer_a[idx]; + float b = buffer_b[idx]; + float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1); + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +__global__ void xla_int32_comparison(int* buffer_a, int* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float elem_a = static_cast(buffer_a[idx]); + float elem_b = static_cast(buffer_b[idx]); + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +} // namespace + +void* fp16_comparison() { + return reinterpret_cast(&xla_fp16_comparison); +} + +void* bf16_comparison() { + return reinterpret_cast(&xla_bf16_comparison); +} + +void* fp32_comparison() { + return reinterpret_cast(&xla_fp32_comparison); +} + +void* fp64_comparison() { + return reinterpret_cast(&xla_fp64_comparison); +} + +void* int8_comparison() { + return reinterpret_cast(&xla_int8_comparison); +} + +void* int32_comparison() { + return reinterpret_cast(&xla_int32_comparison); +} + +} // namespace buffer_comparator +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.h b/third_party/xla/xla/service/gpu/buffer_comparator.h index df4b6215aaf99b..8048cb66c5c976 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.h +++ b/third_party/xla/xla/service/gpu/buffer_comparator.h @@ -13,15 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ -#define XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ -#include "xla/service/hlo_module_config.h" -#include "xla/shape.h" -#include "xla/stream_executor/stream_executor.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/shape.h" -namespace xla { -namespace gpu { +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" + + +#if TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif + +namespace xla::gpu { // A device-side comparator that compares buffers. class BufferComparator { @@ -29,7 +35,8 @@ class BufferComparator { BufferComparator(const BufferComparator&) = delete; BufferComparator(BufferComparator&&) = default; - BufferComparator(const Shape& shape, const HloModuleConfig& config); + BufferComparator(const Shape& shape, double tolerance = 0.1, + bool verbose = true); // Returns true if the two buffers compare equal. The definition of "equal" // is: @@ -40,15 +47,28 @@ class BufferComparator { // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance // // See the implementation for the tolerance value. 
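The comment above pins down the comparator's notion of equality: a bounded relative error with NaNs treated as equal. A minimal host-side sketch of that rule, mirroring the device kernels and the HostCompare fallback (illustrative only; `ElementsMatch` is not a name from this patch):

#include <algorithm>
#include <cmath>

// True when `current` and `expected` agree under the comparator's
// relative-error rule; mutual NaNs count as a match, as in the kernels above.
static bool ElementsMatch(float current, float expected, float rel_tol) {
  if (std::isnan(current) && std::isnan(expected)) return true;
  float rel_error = std::abs(current - expected) /
                    (std::max(std::abs(current), std::abs(expected)) + 1.0f);
  return !(rel_error > rel_tol || std::isnan(rel_error));
}

With the reworked constructor the tolerance and verbosity become explicit, e.g. `BufferComparator comparator(shape, /*tolerance=*/0.1, /*verbose=*/true);` followed by `comparator.CompareEqual(stream, current, expected)`.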
- StatusOr CompareEqual(se::Stream* stream, se::DeviceMemoryBase current, - se::DeviceMemoryBase expected) const; + StatusOr CompareEqual(se::Stream* stream, + se::DeviceMemoryBase current, + se::DeviceMemoryBase expected) const; private: Shape shape_; - HloModuleConfig config_; + double relative_tol_; // relative tolerance for comparison + bool verbose_; // whether to print out error message on mismatch }; +namespace buffer_comparator { + +// Returns a pointer to CUDA C++ device function implementing comparison. +void* fp16_comparison(); +void* bf16_comparison(); +void* fp32_comparison(); +void* fp64_comparison(); +void* int8_comparison(); +void* int32_comparison(); + +} // namespace buffer_comparator +} // namespace xla::gpu + -} // namespace gpu -} // namespace xla -#endif // XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/cublas_cudnn.cc b/third_party/xla/xla/service/gpu/cublas_cudnn.cc index 2c6340c6c260c4..274cc4f1112eab 100644 --- a/third_party/xla/xla/service/gpu/cublas_cudnn.cc +++ b/third_party/xla/xla/service/gpu/cublas_cudnn.cc @@ -36,14 +36,8 @@ bool IsCublasLtMatmul(const HloInstruction& hlo) { hlo.custom_call_target() == kCublasLtMatmulCallTarget; } -bool IsCublasLtMatmulF8(const HloInstruction& hlo) { - return hlo.opcode() == HloOpcode::kCustomCall && - hlo.custom_call_target() == kCublasLtMatmulF8CallTarget; -} - const absl::string_view kGemmCallTarget = "__cublas$gemm"; const absl::string_view kCublasLtMatmulCallTarget = "__cublas$lt$matmul"; -const absl::string_view kCublasLtMatmulF8CallTarget = "__cublas$lt$matmul$f8"; const absl::string_view kTriangularSolveCallTarget = "__cublas$triangularSolve"; const absl::string_view kCudnnConvBackwardInputCallTarget = diff --git a/third_party/xla/xla/service/gpu/cublas_cudnn.h b/third_party/xla/xla/service/gpu/cublas_cudnn.h index 4feaf28bb1f51c..556c0bd9cac3e3 100644 --- a/third_party/xla/xla/service/gpu/cublas_cudnn.h +++ b/third_party/xla/xla/service/gpu/cublas_cudnn.h @@ -85,18 +85,12 @@ bool IsLegacyCublasMatmul(const HloInstruction& hlo); // Matrix multiplication that calls into cublasLt. bool IsCublasLtMatmul(const HloInstruction& hlo); -// Scaled matrix multiplication in FP8. Calls into cublasLt. -bool IsCublasLtMatmulF8(const HloInstruction& hlo); - // A call to cuBLAS general matrix multiplication API. extern const absl::string_view kGemmCallTarget; // A call to cuBLAS Lt API matrix multiplication. extern const absl::string_view kCublasLtMatmulCallTarget; -// A call to cuBLASLt for scaled matrix multiplication in FP8. -extern const absl::string_view kCublasLtMatmulF8CallTarget; - // A call to cuBLAS for a triangular solve. // // Like cudnn convolutions, this op returns a tuple (result, scratch_memory). diff --git a/third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.h b/third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.h deleted file mode 100644 index 03b9655f8aaafe..00000000000000 --- a/third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.h +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ -#define XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ - -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif - -#include -#include -#include - -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "tsl/platform/statusor.h" -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_blas_lt.h" -#else -#include "rocm/rocm_config.h" -#include "xla/stream_executor/rocm/hip_blas_lt.h" -#endif // GOOGLE_CUDA - -namespace xla { -namespace gpu { - -class CublasLtMatmulThunk : public Thunk { - public: - CublasLtMatmulThunk(ThunkInfo thunk_info, GemmConfig gemm_config, - se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, - BufferAllocation::Slice a_buffer, - BufferAllocation::Slice b_buffer, - BufferAllocation::Slice c_buffer, - BufferAllocation::Slice d_buffer, - BufferAllocation::Slice bias_buffer /* may be null */, - BufferAllocation::Slice aux_buffer /* may be null */, - BufferAllocation::Slice a_scale_buffer /* may be null */, - BufferAllocation::Slice b_scale_buffer /* may be null */, - BufferAllocation::Slice c_scale_buffer /* may be null */, - BufferAllocation::Slice d_scale_buffer /* may be null */, - BufferAllocation::Slice d_amax_buffer /* may be null */); - - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - StatusOr GetMatmulPlan( - const stream_executor::Stream* stream); - - absl::Mutex matmul_plans_cache_mutex_; - absl::flat_hash_map> - matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_); - - GemmConfig gemm_config_; - se::gpu::BlasLt::Epilogue epilogue_; - int64_t algorithm_idx_; - BufferAllocation::Slice a_buffer_; - BufferAllocation::Slice b_buffer_; - BufferAllocation::Slice c_buffer_; - BufferAllocation::Slice d_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice aux_buffer_; - BufferAllocation::Slice a_scale_buffer_; - BufferAllocation::Slice b_scale_buffer_; - BufferAllocation::Slice c_scale_buffer_; - BufferAllocation::Slice d_scale_buffer_; - BufferAllocation::Slice d_amax_buffer_; - std::optional algorithm_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc index 07042abd81220c..5a7a7531ea6a4e 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,385 +13,384 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "xla/service/gpu/gemm_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" -#include -#include -#include +#include +#include +#include #include #include -#include -#include #include #include #include -#include "xla/autotuning.pb.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logger.h" -#include "tsl/platform/statusor.h" -#include "tsl/util/proto/proto_utils.h" - -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -#include "xla/service/gpu/buffer_comparator.h" -#include "xla/stream_executor/cuda/cuda_blas_lt.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" -#endif +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h" +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" +#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/stream_executor/blas.h" +#include "tensorflow/compiler/xla/stream_executor/device_description.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/tsl/protobuf/autotuning.pb.h" +#include "tensorflow/tsl/util/proto/proto_utils.h" namespace xla { namespace gpu { +namespace { -// Returns the index (into `algorithms`) of the fastest algorithm. 
-template -StatusOr GetBestAlgorithm( - se::Stream* stream, se::RedzoneAllocator& allocator, - std::optional gemm_str, - const AutotuneConfig& autotune_config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - absl::Span algorithms, const Shape& output_shape, - const HloModuleConfig& hlo_module_config, double beta, - const std::function(const AlgoT&)>& - run_benchmark) { - if (!stream->parent()->SynchronizeAllActivity()) { - return InternalError("Failed to synchronize GPU for autotuning."); - } +using se::gpu::BlasLt; + +class GemmAutotuner { + const AutotuneConfig& autotune_config_; + RedzoneBuffers rz_buffers_; + std::unique_ptr< se::Stream > stream_; + bool deterministic_ops_ = false; + float gemm_relative_tol_ = 0.1f; + + public: + explicit GemmAutotuner(const AutotuneConfig& autotune_config) + : autotune_config_(autotune_config) {} + + StatusOr operator()( + const GemmConfig& gemm_config, + std::vector< Shape >&& input_shapes, const Shape& output_shape, + const DebugOptions& debug_options) { + + VLOG(3) << "Starting autotune of GemmThunk standalone"; + + if(!stream_) { + stream_ = std::make_unique< se::Stream >(autotune_config_.GetExecutor()); + stream_->Init(); + } - se::DeviceMemoryBase reference_buffer; - if (autotune_config.should_check_correctness()) { - TF_ASSIGN_OR_RETURN( - reference_buffer, - allocator.AllocateBytes(ShapeUtil::ByteSizeOf(output_shape))); + deterministic_ops_ = false ; + gemm_relative_tol_ = debug_options.xla_gpu_autotune_gemm_rtol(); + + // Don't run autotuning concurrently on the same GPU. + absl::MutexLock gpu_lock(&GetGpuMutex(stream_->parent())); + + TF_ASSIGN_OR_RETURN(rz_buffers_, RedzoneBuffers::FromShapes( + std::move(input_shapes), output_shape, autotune_config_, stream_.get(), + debug_options, RedzoneBuffers::kAllInputsAllOutputs)); + + return TuneGpuBlasLt(output_shape, gemm_config); } - BufferComparator comparator(output_shape, hlo_module_config); + StatusOr operator()(const HloInstruction* gemm, + const GemmConfig& gemm_config) { - std::vector results; - std::optional reference_algorithm; + VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); - for (const AlgoT& algorithm : algorithms) { - // Make sure the output buffer always has the same value if we use - // the bias parameter. - if (autotune_config.should_reinit_output_buffer() && beta != 0) { - int64_t rng_state = 0; - InitializeBuffer(stream, output_shape.element_type(), &rng_state, - output_buffer); + if(!stream_) { + stream_ = std::make_unique< se::Stream >(autotune_config_.GetExecutor()); + stream_->Init(); } - TF_ASSIGN_OR_RETURN(se::blas::ProfileResult profile_result, - run_benchmark(algorithm)); + const DebugOptions& debug_options = + gemm->GetModule()->config().debug_options(); + deterministic_ops_ = false ; + gemm_relative_tol_ = debug_options.xla_gpu_autotune_gemm_rtol(); - results.emplace_back(); - AutotuneResult& result = results.back(); - result.mutable_gemm()->set_algorithm(profile_result.algorithm()); + // Don't run autotuning concurrently on the same GPU. + absl::MutexLock gpu_lock(&GetGpuMutex(stream_->parent())); + TF_ASSIGN_OR_RETURN(rz_buffers_, RedzoneBuffers::FromInstruction( + *gemm, autotune_config_, stream_.get(), debug_options, + RedzoneBuffers::kAllInputsAllOutputs)); - if (!profile_result.is_valid()) { // Unsupported algorithm. 
- result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED); - continue; - } - - VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took " - << profile_result.elapsed_time_in_ms() << "ms"; + return IsCublasLtMatmul(*gemm) + ? TuneGpuBlasLt(gemm->shape(), gemm_config) + : TuneGpuBlas(gemm->shape(), gemm_config); + } - *result.mutable_run_time() = tsl::proto_utils::ToDurationProto( - absl::Milliseconds(profile_result.elapsed_time_in_ms())); + private: + se::DeviceMemoryBase LhsBuffer() { return rz_buffers_.input_buffers().at(0); } + se::DeviceMemoryBase RhsBuffer() { return rz_buffers_.input_buffers().at(1); } + se::DeviceMemoryBase OutputBuffer() { + return rz_buffers_.output_buffers().at(0); + } - if (!autotune_config.should_check_correctness()) { - continue; + StatusOr TuneGpuBlasLt(const Shape& out_shape, const GemmConfig& gemm_config) { + + se::DeviceMemoryBase workspace_buffer; + if(out_shape.IsTuple()) { + workspace_buffer = rz_buffers_.output_buffers(). + at(out_shape.tuple_shapes_size() - 1); } - TF_ASSIGN_OR_RETURN( - se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, - allocator.CheckRedzones()); - - if (!rz_check_status.ok()) { - result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); - *result.mutable_failure()->mutable_msg() = - rz_check_status.RedzoneFailureMsg(); - LOG(ERROR) << "Detected out-of-bounds write in gemm buffer"; - CHECK(!autotune_config.should_crash_on_check_failure()); - continue; - } + bool has_matrix_bias = gemm_config.beta != 0.; + bool has_vector_bias = ((int)gemm_config.epilogue & (int)BlasLt::Epilogue::kBias) != 0; + bool has_aux_output = (gemm_config.epilogue == BlasLt::Epilogue::kGELUWithAux || + gemm_config.epilogue == BlasLt::Epilogue::kBiasThenGELUWithAux); - if (!reference_algorithm) { - stream->ThenMemcpy(&reference_buffer, output_buffer, - output_buffer.size()); - reference_algorithm = profile_result.algorithm(); - } else { - // Perform the comparison. - TF_ASSIGN_OR_RETURN( - bool outputs_match, - comparator.CompareEqual(stream, /*current=*/output_buffer, - /*expected=*/reference_buffer)); - if (!outputs_match) { - LOG(ERROR) << "Results mismatch between different GEMM algorithms. " - << "This is likely a bug/unexpected loss of precision."; - CHECK(!autotune_config.should_crash_on_check_failure()); + se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer, + d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer; - result.mutable_failure()->set_kind(AutotuneResult::WRONG_RESULT); - result.mutable_failure()->mutable_reference_gemm()->set_algorithm( - *reference_algorithm); - } + if (has_vector_bias) { + bias_buffer = rz_buffers_.input_buffers().at(has_matrix_bias ? 3 : 2); } - } - - if (!autotune_config.should_crash_on_check_failure()) { - AutotuningLog log; - for (const AutotuneResult& result : results) { - *log.add_results() = result; + if (has_aux_output) { + aux_buffer = rz_buffers_.output_buffers().at(1); } - tsl::Logger::GetSingleton()->LogProto(log); + + TF_ASSIGN_OR_RETURN(auto plan, + BlasLt::GetMatmulPlan(stream_.get(), gemm_config)); + + TF_ASSIGN_OR_RETURN( + auto algorithms, + plan->GetAlgorithms(/*max_algorithm_count*/ se::gpu::BlasLt::kMaxAlgorithms, + /*max_workspace_size*/ workspace_buffer.size())); + + auto tuned_func = [&](const BlasLt::MatmulAlgorithm& algorithm) + -> StatusOr { + TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithm)); + // Run a warmup iteration without the profiler active. 
+ TF_RETURN_IF_ERROR(plan->ExecuteOnStream( + stream_.get(), LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(), + bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, + c_scale_buffer, d_scale_buffer, d_amax_buffer, + workspace_buffer)); + + se::blas::ProfileResult profile_result; + TF_RETURN_IF_ERROR(plan->ExecuteOnStream( + stream_.get(), LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(), + bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, + c_scale_buffer, d_scale_buffer, d_amax_buffer, + workspace_buffer, absl::nullopt, &profile_result)); + return std::move(profile_result); + }; + + const auto& shape = out_shape.IsTuple() ? out_shape.tuple_shapes(0) + : out_shape; + return GetBestAlgorithm( + shape, algorithms, gemm_config.beta, false, tuned_func); } - StatusOr best = - PickBestResult(results, gemm_str, hlo_module_config); - if (best.ok()) { - for (size_t i = 0; i < results.size(); ++i) { - if (best->gemm().algorithm() == results[i].gemm().algorithm()) { - best->mutable_gemm()->set_algorithm(i); - return best; - } + StatusOr TuneGpuBlas(const Shape& out_shape, + const GemmConfig& gemm_config) { +#if 0 + auto workspace_buffer = rz_buffers_.output_buffers().at(1); + + std::vector algorithms; + TF_ASSIGN_OR_RETURN(GemmConfig::DescriptorsTuple desc, + gemm_config.GetMatrixDescriptors( + LhsBuffer(), RhsBuffer(), OutputBuffer())); + + auto blas = stream_->parent()->AsBlas(); + if (blas == nullptr) { + return xla::InternalError("No BLAS support for stream"); } - return InternalError("unknown best algorithm"); + blas->GetBlasGemmAlgorithms(stream_.get(), desc.lhs, desc.rhs, &desc.output, + &gemm_config.alpha, &gemm_config.beta, + &algorithms); + + auto tuned_func = [&](const se::blas::AlgorithmType& algorithm) + -> StatusOr { + // Do a warm-up run first, without a profile result. RunGemm swallows + // error codes when profile_result is passed, as it is in the measurement + // below, but not otherwise. It is, therefore, consistent to ignore the + // error code here. + static_cast(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(), + OutputBuffer(), workspace_buffer, + deterministic_ops_, stream_.get(), algorithm)); + se::blas::ProfileResult profile_result; + // Allow GpuTimer to use its delay kernel implementation to improve + // accuracy. + profile_result.set_warmup_run_executed(true); + // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail + // for all algorithms if we're targeting < sm_50. But because we pass a + // non-null ProfileResult, DoGemmWithAlgorithm should always return true, + // and the actual success-ness is returned in ProfileResult::is_valid. + TF_RETURN_IF_ERROR(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(), + OutputBuffer(), workspace_buffer, + deterministic_ops_, stream_.get(), algorithm, + &profile_result)); + return std::move(profile_result); + }; + + const auto& shape = out_shape.IsTuple() ? out_shape.tuple_shapes(0) + : out_shape; + return GetBestAlgorithm( + shape, algorithms, gemm_config.beta, false, tuned_func); +#else + return tensorflow::AutotuneResult{}; +#endif } - LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance " - "might be suboptimal: " - << best.status(); - best->clear_gemm(); - return best; -} + // Returns the index (into `algorithms`) of the fastest algorithm. + template + StatusOr GetBestAlgorithm( + const Shape& output_shape, absl::Span algorithms, + double beta, bool return_algo_index, TunedFunc&& run_benchmark) { -// Select the best algorithm using information from a Blas instruction. 
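TuneGpuBlasLt above exercises the full BlasLt plan lifecycle: build a plan for the GemmConfig, enumerate candidate algorithms under a workspace budget, then execute each candidate and collect a ProfileResult. A condensed sketch of that lifecycle, assuming `stream`, `gemm_config`, the operand buffers and a `workspace` allocation are already in scope (illustrative, not a verbatim excerpt of the patch):

TF_ASSIGN_OR_RETURN(auto plan,
                    se::gpu::BlasLt::GetMatmulPlan(stream, gemm_config));
TF_ASSIGN_OR_RETURN(
    auto algorithms,
    plan->GetAlgorithms(/*max_algorithm_count=*/se::gpu::BlasLt::kMaxAlgorithms,
                        /*max_workspace_size=*/workspace.size()));
for (const auto& algorithm : algorithms) {
  TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithm));
  se::blas::ProfileResult profile;
  // A single profiled run per candidate; the tuner above adds a warm-up
  // execution first so one-time costs do not skew the timing.
  TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
      stream, lhs, rhs, output, output, bias, aux, a_scale, b_scale,
      c_scale, d_scale, d_amax, workspace, absl::nullopt, &profile));
  // profile.elapsed_time_in_ms() feeds the AutotuneResult for this candidate.
}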
-// Returns the index (into `algorithms`) of the fastest algorithm. -StatusOr GetBestBlasAlgorithm( - se::Stream* stream, se::RedzoneAllocator& allocator, - std::optional gemm_str, - const AutotuneConfig& autotune_config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - absl::Span algorithms, - const Shape& output_shape, const HloModuleConfig& hlo_module_config, - double beta, - const std::function( - const se::blas::AlgorithmType&)>& run_benchmark) { - return GetBestAlgorithm( - stream, allocator, gemm_str, autotune_config, lhs_buffer, rhs_buffer, - output_buffer, algorithms, output_shape, hlo_module_config, beta, - run_benchmark); -} - -namespace { + if (!stream_->parent()->SynchronizeAllActivity()) { + return Internal("Failed to synchronize GPU for autotuning."); + } -StatusOr AsBlasLtEpilogue( - GemmBackendConfig_Epilogue epilogue) { - switch (epilogue) { - case GemmBackendConfig::DEFAULT: - return se::cuda::BlasLt::Epilogue::kDefault; - case GemmBackendConfig::RELU: - return se::cuda::BlasLt::Epilogue::kReLU; - case GemmBackendConfig::GELU: - return se::cuda::BlasLt::Epilogue::kGELU; - case GemmBackendConfig::GELU_AUX: - return se::cuda::BlasLt::Epilogue::kGELUWithAux; - case GemmBackendConfig::BIAS: - return se::cuda::BlasLt::Epilogue::kBias; - case GemmBackendConfig::BIAS_RELU: - return se::cuda::BlasLt::Epilogue::kBiasThenReLU; - case GemmBackendConfig::BIAS_GELU: - return se::cuda::BlasLt::Epilogue::kBiasThenGELU; - case GemmBackendConfig::BIAS_GELU_AUX: - return se::cuda::BlasLt::Epilogue::kBiasThenGELUWithAux; - default: - return InternalError("Unsupported Epilogue."); - } -} + se::DeviceMemoryBase reference_buffer; + if (autotune_config_.should_check_correctness()) { + TF_ASSIGN_OR_RETURN(reference_buffer, + rz_buffers_.RedzoneAllocator().AllocateBytes( + ShapeUtil::ByteSizeOf(output_shape))); + } -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) + // Do not print error messages if should_skip_wrong_results() is ON. + BufferComparator comparator(output_shape, gemm_relative_tol_, + /* verbose */!autotune_config_.should_skip_wrong_results() + ); + std::vector results; + results.reserve(algorithms.size()); + absl::optional reference_algorithm; + + for (size_t i = 0; i < algorithms.size(); i++) { + const AlgoT& algorithm = algorithms[i]; + // Make sure the output buffer always has the same value if we use + // the bias parameter. + if (autotune_config_.should_reinit_output_buffer() && beta != 0) { + int64_t rng_state = 0; + InitializeBuffer(stream_.get(), output_shape.element_type(), &rng_state, + OutputBuffer()); + } + TF_ASSIGN_OR_RETURN(auto profile_result, run_benchmark(algorithm)); -StatusOr DoGemmAutotuneNoCache( - const HloInstruction* gemm, const AutotuneCacheKey& key, - const AutotuneConfig& autotune_config) { - if (autotune_config.IsDeviceless()) { - // Return empty result, will tune at runtime. 
- return AutotuneResult{}; - } + results.emplace_back(); + tensorflow::AutotuneResult& result = results.back(); + result.mutable_gemm()->set_algorithm(profile_result.algorithm()); - VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); - se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator(); - TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream()); - GemmBackendConfig gemm_config = - gemm->backend_config().value(); - const DebugOptions& debug_options = - gemm->GetModule()->config().debug_options(); - const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops(); - - TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm)); - // Don't run autotuning concurrently on the same GPU. - absl::MutexLock gpu_lock(&GetGpuMutex(stream->parent())); - - TF_ASSIGN_OR_RETURN( - se::RedzoneAllocator buffer_allocator, - AutotunerUtil::CreateRedzoneAllocator(autotune_config, debug_options)); - - int64_t rng_state = 0; - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase lhs_buffer, - AutotunerUtil::CreateBuffer(buffer_allocator, gemm->operand(0)->shape(), - autotune_config, rng_state)); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase rhs_buffer, - AutotunerUtil::CreateBuffer(buffer_allocator, gemm->operand(1)->shape(), - autotune_config, rng_state)); - - const Shape& output_shape = - gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0) : gemm->shape(); - - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase output_buffer, - AutotunerUtil::CreateBuffer(buffer_allocator, output_shape, - autotune_config, rng_state)); - - HloModuleConfig& hlo_module_config = gemm->GetModule()->config(); - AutotuneResult best_algorithm; - if (IsCublasLtMatmul(*gemm)) { - bool has_matrix_bias = config.beta != 0.; - - TF_ASSIGN_OR_RETURN(bool has_vector_bias, cublas_lt::EpilogueAddsVectorBias( - gemm_config.epilogue())); + if (!profile_result.is_valid()) { // Unsupported algorithm. + result.mutable_failure()->set_kind(tensorflow::AutotuneResult::DISQUALIFIED); + continue; + } - TF_ASSIGN_OR_RETURN( - bool has_aux_output, - cublas_lt::EpilogueHasAuxiliaryOutput(gemm_config.epilogue())); + VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took " + << profile_result.elapsed_time_in_ms() << "ms"; - TF_ASSIGN_OR_RETURN(auto epilogue, - AsBlasLtEpilogue(gemm_config.epilogue())); + *result.mutable_run_time() = tsl::proto_utils::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); - se::DeviceMemoryBase bias_buffer; - if (has_vector_bias) { + if (!autotune_config_.should_check_correctness()) { + continue; + } TF_ASSIGN_OR_RETURN( - bias_buffer, - AutotunerUtil::CreateBuffer( - buffer_allocator, gemm->operand(has_matrix_bias ? 
3 : 2)->shape(), - autotune_config, rng_state)); - } - se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer, - d_scale_buffer, d_amax_buffer; + se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, + rz_buffers_.RedzoneAllocator().CheckRedzones()); + + if (!rz_check_status.ok()) { + result.mutable_failure()->set_kind(tensorflow::AutotuneResult::REDZONE_MODIFIED); + *result.mutable_failure()->mutable_msg() = + rz_check_status.RedzoneFailureMsg(); + LOG(ERROR) << "Detected out-of-bounds write in gemm buffer"; + CHECK(!autotune_config_.should_crash_on_check_failure()); + continue; + } - se::DeviceMemoryBase aux_buffer; - if (has_aux_output) { + if (!reference_algorithm) { + stream_->ThenMemcpy(&reference_buffer, OutputBuffer(), + OutputBuffer().size()); + reference_algorithm = profile_result.algorithm(); + continue; + } + // Perform the comparison versus the reference algorithm. TF_ASSIGN_OR_RETURN( - aux_buffer, AutotunerUtil::CreateBuffer(buffer_allocator, - gemm->shape().tuple_shapes(1), - autotune_config, rng_state)); - } - - TF_ASSIGN_OR_RETURN(auto plan, - cublas_lt::MatmulPlan::From(config, epilogue)); - TF_ASSIGN_OR_RETURN( - std::vector algorithms, - plan.GetAlgorithms(stream)); - - TF_ASSIGN_OR_RETURN( - best_algorithm, - GetBestAlgorithm( - stream, buffer_allocator, gemm->ToString(), autotune_config, - lhs_buffer, rhs_buffer, output_buffer, algorithms, output_shape, - hlo_module_config, gemm_config.beta(), - [&](const se::cuda::BlasLt::MatmulAlgorithm& algorithm) - -> StatusOr { - se::OwningScratchAllocator<> scratch_allocator( - stream->parent()->device_ordinal(), allocator); - se::blas::ProfileResult profile_result; - TF_RETURN_IF_ERROR(plan.ExecuteOnStream( - stream, lhs_buffer, rhs_buffer, output_buffer, output_buffer, - bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, - c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm, - scratch_allocator, &profile_result)); - return std::move(profile_result); - })); - } else { - std::vector algorithms; - TF_RET_CHECK(stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms)); - - TF_ASSIGN_OR_RETURN(best_algorithm, - GetBestBlasAlgorithm( - stream, buffer_allocator, gemm->ToString(), - autotune_config, lhs_buffer, rhs_buffer, - output_buffer, algorithms, output_shape, - hlo_module_config, gemm_config.beta(), - [&](const se::blas::AlgorithmType& algorithm) - -> StatusOr { - se::blas::ProfileResult profile_result; - // We expect GemmWithAlgorithm to fail sometimes - // -- in fact, it will fail for all algorithms if - // we're targeting < sm_50. But because we pass a - // non-null ProfileResult, DoGemmWithAlgorithm - // should always return true, and the actual - // success-ness is returned in - // ProfileResult::is_valid. - TF_RETURN_IF_ERROR( - RunGemm(config, lhs_buffer, rhs_buffer, - output_buffer, deterministic_ops, - stream, algorithm, &profile_result)); - return std::move(profile_result); - })); - if (best_algorithm.has_gemm()) { - int alg_idx = best_algorithm.gemm().algorithm(); - best_algorithm.mutable_gemm()->set_algorithm(algorithms[alg_idx]); + bool outputs_match, + comparator.CompareEqual(stream_.get(), /*current=*/OutputBuffer(), + /*expected=*/reference_buffer)); + if (!outputs_match) { + LOG(ERROR) << "Results mismatch between different GEMM algorithms. 
" + << "This is likely a bug/unexpected loss of precision."; + CHECK(!autotune_config_.should_crash_on_check_failure()); + + // By default, autotuner does NOT really skip wrong results, but + // merely prints out the above error message: this may lead to a + // great confusion. When should_skip_wrong_results() is set to true, + // solutions with accuracy problems will be disqualified. + auto kind = tensorflow::AutotuneResult::WRONG_RESULT; + if (autotune_config_.should_skip_wrong_results()) { + kind = tensorflow::AutotuneResult::DISQUALIFIED; + } + result.mutable_failure()->set_kind(kind); + result.mutable_failure()->mutable_reference_gemm()->set_algorithm( + *reference_algorithm); + } + } // for algorithms + + StatusOr best_res = + PickBestResult(results, absl::nullopt); + if (best_res.ok()) { + auto best = std::move(best_res.value()); + // Return a real algorithm ID if return_algo_index is false: + // e.g., in case of legacy cublas tuning. + if (!return_algo_index) return best; + // Otherwise, map a real algorithm ID to its index among the results. + for (size_t i = 0; i < results.size(); ++i) { + if (best.gemm().algorithm() == results[i].gemm().algorithm()) { + best.mutable_gemm()->set_algorithm(i); + return best; + } + } + return Internal("unknown best algorithm"); } - } - return best_algorithm; -} - -#endif + LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance " + "might be suboptimal: " + << best_res.status(); + return tensorflow::AutotuneResult{}; + } // GetBestAlgorithm +}; // GemmAutotuner // Do Gemm Autotune without stream executor. Use results from autotune cache // only. StatusOr RunOnInstruction(HloInstruction* gemm, const AutotuneConfig& config) { VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString(); + TF_ASSIGN_OR_RETURN(auto backend_config, + gemm->backend_config()); - AutotuneCacheKey key(config.GetModelStr(), *gemm); - - TF_ASSIGN_OR_RETURN(AutotuneResult algorithm, - AutotunerUtil::Autotune(gemm, config, [&] { - return DoGemmAutotuneNoCache(gemm, key, config); - })); - - se::CudaComputeCapability capability = config.GetCudaComputeCapability(); - GemmBackendConfig gemm_config = - gemm->backend_config().value(); - GemmBackendConfig updated_config = gemm_config; - - // We only set the 'algorithm' field on non-Ampere architectures, as for - // Ampere it's ignored in any case. - if (!capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) { - if (algorithm.has_gemm()) { - updated_config.set_selected_algorithm(algorithm.gemm().algorithm()); - } else { - updated_config.set_selected_algorithm(se::blas::kRuntimeAutotuning); - } + // Degenerate gemms replaced with memzero operation, no need to auto tune it. + if (backend_config.alpha_real() == 0.0 && + backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) { + VLOG(3) << "Skip degenerate gemm instruction auto tuning"; + return false; } - TF_RETURN_IF_ERROR(gemm->set_backend_config(updated_config)); - return updated_config.SerializeAsString() != gemm_config.SerializeAsString(); + + TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm)); + + GemmAutotuner autotuner(config); + TF_ASSIGN_OR_RETURN(auto new_algorithm, + AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, true), config, + [&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto algo, autotuner(gemm, gemm_config)); + return algo.has_gemm() ? 
algo.gemm().algorithm() : se::blas::kDefaultAlgorithm; + })); + + auto old_algorithm = backend_config.selected_algorithm(); + if (new_algorithm == old_algorithm) { + // We don't need to update the backend config if + // the algorithm hasn't changed unless previously + // the algorithm wasn't set explicitly. + return false; + } + + backend_config.set_selected_algorithm(new_algorithm); + TF_RETURN_IF_ERROR(gemm->set_backend_config(backend_config)); + return true; // We changed `gemm` } StatusOr RunOnComputation(HloComputation* computation, - AutotuneConfig config) { + AutotuneConfig config) { bool changed = false; + for (HloInstruction* instr : computation->instructions()) { - if (IsCublasGemm(*instr)) { + //if (IsCublasGemm(*instr)) { + if (IsCublasLtMatmul(*instr)) { // NOTE: legacy cublas autotuning is NYI ! TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, config)); changed |= result; } @@ -401,9 +400,24 @@ StatusOr RunOnComputation(HloComputation* computation, } // namespace -StatusOr GemmAlgorithmPicker::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { +StatusOr GemmAlgorithmPicker::RunStandalone( + const se::gpu::GemmConfig& cfg, + std::vector< Shape >&& input_shapes, const Shape& output_shape, + const DebugOptions& debug_options) { + + GemmAutotuner autotuner(config_); + GemmConfig gemm_config{cfg}; + + return AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, true), config_, + [&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto algo, autotuner(gemm_config, std::move(input_shapes), + output_shape, debug_options)); + return algo.has_gemm() ? algo.gemm().algorithm() : se::blas::kDefaultAlgorithm; + }); +} + +StatusOr GemmAlgorithmPicker::Run(HloModule* module, + const absl::flat_hash_set& threads) { XLA_SCOPED_LOGGING_TIMER( absl::StrCat("GemmAlgorithmPicker for ", module->name())); @@ -414,7 +428,7 @@ StatusOr GemmAlgorithmPicker::Run( bool changed = false; for (HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { + module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_)); changed |= result; } @@ -422,4 +436,4 @@ StatusOr GemmAlgorithmPicker::Run( } } // namespace gpu -} // namespace xla +} // namespace xla \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h index 9675d8968bfebf..4bde40470a7399 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h @@ -17,40 +17,36 @@ limitations under the License. 
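Each candidate benchmarked by GetBestAlgorithm ends up as a tensorflow::AutotuneResult carrying either a run time or a failure kind (DISQUALIFIED, REDZONE_MODIFIED, WRONG_RESULT). A small helper showing how such a result can be inspected (illustrative; `LogResult` is not part of the patch):

void LogResult(const tensorflow::AutotuneResult& res) {
  if (res.has_failure()) {
    LOG(INFO) << "candidate rejected: " << res.failure().msg();
    return;
  }
  LOG(INFO) << "algorithm " << res.gemm().algorithm() << " took "
            << tsl::proto_utils::FromDurationProto(res.run_time());
}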
#include #include -#include #include -#include - -#include "absl/strings/string_view.h" -#include "xla/autotune_results.pb.h" -#include "xla/autotuning.pb.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/hlo_pass_interface.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/stream_executor.h" - -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -#include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/stream_executor/cuda/cuda_blas_lt.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" -#endif + +//#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" + +#include "tensorflow/tsl/protobuf/autotuning.pb.h" + +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" + +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/stream_executor/blas.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" + +namespace stream_executor { +namespace gpu { + struct GemmConfig; +} + +} namespace xla { namespace gpu { -StatusOr GetBestBlasAlgorithm( - se::Stream* stream, se::RedzoneAllocator& allocator, - std::optional gemm_str, - const AutotuneConfig& autotune_config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - absl::Span algorithms, - const Shape& output_shape, const HloModuleConfig& hlo_module_config, - double beta, - const std::function( - const se::blas::AlgorithmType&)>& run_benchmark); // GemmAlgorithmPicker supports two modes: device and deviceless. // In device mode, we run autotuning on the device and store autotune results. @@ -59,14 +55,22 @@ StatusOr GetBestBlasAlgorithm( // autotune result is not stored, then algorithm is set to kRuntimeAutotuning. 
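The comment above summarizes the two operating modes. In either mode the pass is scheduled like any other HLO pass; a minimal sketch, assuming an AutotuneConfig and an HloModule* are already available (pipeline name and variables are illustrative):

HloPassPipeline pipeline("gemm-algorithm-picking");
pipeline.AddPass<GemmAlgorithmPicker>(autotune_config);
TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(hlo_module));
// In deviceless mode only cached autotune results are consulted; per the
// comment above, a missing entry leaves the algorithm at kRuntimeAutotuning.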
class GemmAlgorithmPicker : public HloModulePass { public: - explicit GemmAlgorithmPicker(AutotuneConfig config) : config_(config) {} + explicit GemmAlgorithmPicker(AutotuneConfig config): config_(config) {} absl::string_view name() const override { return "gemm-algorithm-picker"; } + const AutotuneConfig& config() const { + return config_; + } + using HloPassInterface::Run; - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; + StatusOr Run(HloModule* module, + const absl::flat_hash_set& threads) override; + + StatusOr RunStandalone( + const se::gpu::GemmConfig& gemm_config, + std::vector< Shape >&& input_shapes, const Shape& output_shape, + const DebugOptions& debug_options); private: AutotuneConfig config_; @@ -75,4 +79,4 @@ class GemmAlgorithmPicker : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ +#endif // XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_rewriter.cc index 7676f54f8aab85..cdeb30448a724f 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter.cc @@ -65,18 +65,18 @@ namespace m = match; Status SetName(HloModule *module, HloInstruction *gemm) { if (IsCublasLtMatmul(*gemm)) { module->SetAndUniquifyInstrName(gemm, "cublas-lt-matmul"); - return OkStatus(); + return absl::OkStatus(); } - GemmBackendConfig config; - TF_ASSIGN_OR_RETURN(config, gemm->backend_config()); + TF_ASSIGN_OR_RETURN(auto config, + gemm->backend_config()); const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers(); bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() || !dot_dims.rhs_batch_dimensions().empty(); module->SetAndUniquifyInstrName( gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm"); - return OkStatus(); + return absl::OkStatus(); } // Returns whether a given PrimitiveType is supported by cuBLASLt Epilogue @@ -366,16 +366,6 @@ auto GemmOrCublasLtMatmul(HloInstruction **instr) { return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget}); } -auto CublasLtMatmulMaybeF8(HloInstruction **instr) { - return m::CustomCall( - instr, {kCublasLtMatmulCallTarget, kCublasLtMatmulF8CallTarget}); -} - -auto GemmOrCublasLtMatmulMaybeF8(HloInstruction **instr) { - return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget, - kCublasLtMatmulF8CallTarget}); -} - auto BcastConstScalar(HloInstruction **instr, double value) { return m::Broadcast(instr, m::ConstantScalar(value)); } @@ -431,6 +421,12 @@ auto OptionalBitcast(HloInstruction **optional_bitcast, Pattern pattern) { std::move(pattern)); } +template +auto OptionalBitcast(HloInstruction **optional_bitcast, Pattern pattern) { + return m::AnyOf(m::Bitcast(optional_bitcast, pattern), + std::move(pattern)); +} + // The rewriting proceeds in a bottom-up way: // // (kDot A B) is rewritten into a (kCustomCall:gemm A B) @@ -1359,8 +1355,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *bias = broadcast->mutable_operand(0); - TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); - + TF_ASSIGN_OR_RETURN(auto config, + gemm->backend_config()); // # output column dims == # non-contracting rhs operand dims. 
const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers(); size_t num_col_dims = gemm->operand(1)->shape().rank() - @@ -1518,7 +1514,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { config.set_epilogue(has_aux ? GemmBackendConfig::BIAS_GELU_AUX : GemmBackendConfig::BIAS_GELU); } else { - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr output = gemm->CloneWithNewShape( @@ -1661,7 +1657,149 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { {ComputationType::kF64, DataType::kComplexDouble, PrimitiveType::C128, PrimitiveType::C128, DataType::kComplexDouble}, - }}; + }; + + return absl::c_linear_search( + supported_type_combinations, + std::make_tuple(compute_type, scale_type, a_dtype, b_dtype, + output_dtype)); + } + + StatusOr TypesAreSupportedByCublasLt( + const HloInstruction &instr, const GemmBackendConfig &backend_config, + const HloInstruction *bias = nullptr) const { + // Figure out the Atype/Btype. + const PrimitiveType a_dtype = instr.operand(0)->shape().element_type(); + const PrimitiveType b_dtype = instr.operand(1)->shape().element_type(); + const PrimitiveType output_type = + bias ? bias->shape().element_type() : instr.shape().element_type(); + const std::array supported_type = { + PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN, + PrimitiveType::S8, PrimitiveType::F16, + PrimitiveType::BF16, PrimitiveType::F32, + PrimitiveType::S32, PrimitiveType::F64, + PrimitiveType::C64, PrimitiveType::C128}; + if (!absl::c_linear_search(supported_type, output_type)) return false; + // cublasLt has a defined set of combinations of types that it supports. + // Figure out the computeType and scaleType. + TF_ASSIGN_OR_RETURN(auto output_dtype, se::gpu::AsBlasDataType(output_type)); + TF_ASSIGN_OR_RETURN(auto blas_a_dtype, se::gpu::AsBlasDataType(a_dtype)); + + const int max_precision = *absl::c_max_element( + backend_config.precision_config().operand_precision()); + + TF_ASSIGN_OR_RETURN( + auto compute_type, + se::gpu::GetBlasComputationType( + blas_a_dtype, output_dtype, max_precision)); + se::blas::DataType scale_type = + se::gpu::GetScaleType(output_dtype, compute_type); + + using se::blas::ComputationType; + using se::blas::DataType; + using TypeCombinations = std::initializer_list>; + // This matrix of supported types is taken directly from cublasLt + // documentation. 
+ // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul + const TypeCombinations supported_cublas_type_combinations = { + // FP8 types: + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E4M3FN, DataType::kBF16}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E4M3FN, DataType::kHalf}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E4M3FN, DataType::kFloat}, + + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kBF16}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kF8E4M3FN}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kF8E5M2}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kHalf}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kFloat}, + + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kBF16}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kF8E5M2}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kHalf}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kFloat}, + // There would be an entry here for A/BType complex int8, but we do + // not support that type. 
+ {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64, + PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kF16AsF32, DataType::kComplexFloat, + PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat}, + // The next 4 may be supported by hipblaslt, but they are not + // covered by any unit tests + {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kBF16AsF32, DataType::kComplexFloat, + PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kTF32AsF32, DataType::kComplexFloat, + PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64, + PrimitiveType::F64, DataType::kDouble}, + {ComputationType::kF64, DataType::kComplexDouble, PrimitiveType::C128, + PrimitiveType::C128, DataType::kComplexDouble}, + }; + if (IsCuda(gpu_version_) && + absl::c_linear_search(supported_cublas_type_combinations, + std::tuple{compute_type, scale_type, a_dtype, + b_dtype, output_dtype})) { + return true; + } + const TypeCombinations supported_type_combinations = { + // Other data types: + + {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8, + PrimitiveType::S8, DataType::kInt32}, + {ComputationType::kI32, DataType::kFloat, PrimitiveType::S8, + PrimitiveType::S8, DataType::kInt8}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8, + PrimitiveType::S8, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kBF16}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kFloat}, + {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kBF16}, + {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kHalf}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kFloat}, + {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kHalf}, + {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + }; return absl::c_linear_search( supported_type_combinations, @@ -1805,6 +1943,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const DotDimensionNumbers &dot_dims = gemm_backend_config.dot_dimension_numbers(); + // We use ALG_UNSET and kDefaultComputePrecision because we don't care about + // the precision, just the layout, since we're just checking if the matrix + // is column-major. 
TF_ASSIGN_OR_RETURN( GemmConfig gemm_config, GemmConfig::For( diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.h b/third_party/xla/xla/service/gpu/gemm_rewriter.h index 776a3756e9e20a..0fa75e8b1657dc 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter.h +++ b/third_party/xla/xla/service/gpu/gemm_rewriter.h @@ -36,9 +36,10 @@ using ComputeCap = se::RocmComputeCapability; // (kMultiply (kDot A B) alpha) // (kMultiply C beta)) // -// where A, B, C are matrixes and `alpha` and `beta` are host constants. -// The additional requirement is that C has no other users (otherwise, -// it does not make sense to fuse it inside the custom call). +// where A, B, C are matrices or vectors and `alpha` and `beta` are host +// constants. In matrix-vector multiplication, one operand must be a matrix and +// the other must be a vector. The additional requirement is that C has no other +// users (otherwise, it does not make sense to fuse it inside the custom call). // // Both multiplication and addition can be avoided (equivalent to setting // `alpha` to one and `beta` to zero). diff --git a/third_party/xla/xla/service/gpu/gemm_thunk.cc b/third_party/xla/xla/service/gpu/gemm_thunk.cc index 12017a6b0044aa..cd968b0d39b1e9 100644 --- a/third_party/xla/xla/service/gpu/gemm_thunk.cc +++ b/third_party/xla/xla/service/gpu/gemm_thunk.cc @@ -43,7 +43,9 @@ Status GemmThunk::ExecuteOnStream(const ExecuteParams& params) { const BufferAllocations& allocs = *params.buffer_allocations; return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_), allocs.GetDeviceAddress(rhs_buffer_), - allocs.GetDeviceAddress(output_buffer_), deterministic_, + allocs.GetDeviceAddress(output_buffer_), + workspace_buffer, + deterministic_, params.stream); } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index c97ee7b8441fb7..f87db9750082cd 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1071,6 +1071,10 @@ StatusOr> GpuCompiler::RunHloPasses( [&] { return absl::StrCat("HLO Transforms:", module->name()); }, tsl::profiler::TraceMeLevel::kInfo); + const DebugOptions& debug_opts = module->config().debug_options(); + auto cfg = GetAutotuneConfig(stream_exec, debug_opts, nullptr); + TF_RETURN_IF_ERROR(AutotunerUtil::LoadAutotuneResultsFromFileOnce(cfg)); + GpuTargetConfig gpu_target_config = GetGpuTargetConfig(stream_exec); TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), stream_exec, options, gpu_target_config, diff --git a/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.cc b/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.cc new file mode 100644 index 00000000000000..6cdb7d970328d0 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.cc @@ -0,0 +1,184 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include "xla/service/gpu/gpublas_lt_matmul_thunk.h" +#include "xla/debug_options_flags.h" +//#include "xla/service/gpu/autotuner_util.h" + +namespace xla { +namespace gpu { + +struct MatmulPlanCache { + + static MatmulPlanCache& i(const se::Stream *stream) { + static absl::Mutex m(absl::kConstInit); + // Each GPU gets different cache instance + static std::vector< std::unique_ptr< MatmulPlanCache > > meta(8); + absl::MutexLock lock(&m); + size_t dev_id = stream->parent()->device_ordinal(); + if (meta.size() < dev_id) meta.resize(dev_id + 1); + auto& res = meta[dev_id]; + if (!res) res.reset(new MatmulPlanCache()); + return *res; + } + + template < class Func > + StatusOr + GetOrCreate(const std::string& key, Func&& create) { + // each GPU has a different mutex => hence different GPU instances can + // create matmul plans in parallel + absl::MutexLock lock(mutex_.get()); + auto res = map_.emplace(key, se::gpu::BlasLt::MatmulPlanPtr{}); + if(res.second) { // new entry inserted + TF_ASSIGN_OR_RETURN(res.first->second, create()); + } + return res.first->second.get(); + } + +private: + MatmulPlanCache() : mutex_(std::make_unique< absl::Mutex >()) { } + +private: + std::unique_ptr< absl::Mutex > mutex_; + absl::flat_hash_map map_; +}; + +CublasLtMatmulThunk::CublasLtMatmulThunk( + ThunkInfo thunk_info, GemmConfig config, + BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, + BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, + BufferAllocation::Slice bias_buffer /* may be null */, + BufferAllocation::Slice aux_buffer /* may be null */, + BufferAllocation::Slice a_scale /* may be null */, + BufferAllocation::Slice b_scale /* may be null */, + BufferAllocation::Slice c_scale /* may be null */, + BufferAllocation::Slice d_scale /* may be null */, + BufferAllocation::Slice d_amax /* may be null */, + absl::optional workspace) + : Thunk(Kind::kCublasLtMatmul, thunk_info), + gemm_config_(std::move(config)), + a_buffer_(a_buffer), + b_buffer_(b_buffer), + c_buffer_(c_buffer), + d_buffer_(d_buffer), + bias_buffer_(bias_buffer), + aux_buffer_(aux_buffer), + a_scale_buffer_(a_scale), + b_scale_buffer_(b_scale), + c_scale_buffer_(c_scale), + d_scale_buffer_(d_scale), + d_amax_buffer_(d_amax), + workspace_buffer_(workspace) +{ + canonical_hlo_ = se::gpu::ToCSVString(gemm_config_, /*full_string*/true); + // set algorithm ID explicitly to -1 if tuning is disabled! 
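The MatmulPlanCache above is a get-or-create map keyed by the canonicalized GemmConfig string (canonical_hlo_), with one cache instance per device ordinal so that plans for different GPUs can be built concurrently. A reduced sketch of the same pattern, with the Status plumbing left out and a hypothetical PlanPtr standing in for se::gpu::BlasLt::MatmulPlanPtr (note the growth guard: the resize has to use <= so that the slot being indexed actually exists):

    #include <memory>
    #include <string>
    #include <vector>
    #include "absl/container/flat_hash_map.h"
    #include "absl/synchronization/mutex.h"

    template <typename PlanPtr>
    class PerDevicePlanCache {
     public:
      static PerDevicePlanCache& ForDevice(int dev_id) {
        static absl::Mutex mu(absl::kConstInit);
        static auto* caches =
            new std::vector<std::unique_ptr<PerDevicePlanCache>>();
        absl::MutexLock lock(&mu);
        // Grow while size() <= dev_id so that (*caches)[dev_id] is valid.
        if (caches->size() <= static_cast<size_t>(dev_id)) {
          caches->resize(dev_id + 1);
        }
        auto& slot = (*caches)[dev_id];
        if (!slot) slot.reset(new PerDevicePlanCache());
        return *slot;
      }

      // Returns the cached plan for `key`, building it once on the first miss.
      template <typename Factory>
      PlanPtr& GetOrCreate(const std::string& key, Factory&& create) {
        absl::MutexLock lock(&mu_);
        auto [it, inserted] = plans_.emplace(key, PlanPtr{});
        if (inserted) it->second = create();
        return it->second;
      }

     private:
      PerDevicePlanCache() = default;
      absl::Mutex mu_;  // per-device lock: different GPUs can build plans in parallel
      absl::flat_hash_map<std::string, PlanPtr> plans_;
    };

As for the comment just above, se::blas::kDefaultAlgorithm is that -1 sentinel, and it is what the xla_gpu_autotune_level() check below assigns when autotuning is disabled.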
+ + if(GetDebugOptionsFromFlags().xla_gpu_autotune_level() == 0) { + gemm_config_.algorithm = se::blas::kDefaultAlgorithm; + } +} + +Status CublasLtMatmulThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + if (!executor->AsBlas()) { + return InternalError("Failed to initialize BLASLT support"); + } + return OkStatus(); +} + +Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { + + TF_ASSIGN_OR_RETURN(auto *plan, GetCachedMatmulPlan(params)); + + VLOG(2) << params.stream->parent()->device_ordinal() << + ": cublas_lt_matmul for: " << canonical_hlo_; + const BufferAllocations& allocs = *params.buffer_allocations; + + se::DeviceMemoryBase bias, a_scale, b_scale, c_scale, d_scale, d_amax; + if (bias_buffer_.allocation() != nullptr) { + bias = allocs.GetDeviceAddress(bias_buffer_); + } + if (a_scale_buffer_.allocation() != nullptr) { + a_scale = allocs.GetDeviceAddress(a_scale_buffer_); + } + if (b_scale_buffer_.allocation() != nullptr) { + b_scale = allocs.GetDeviceAddress(b_scale_buffer_); + } + if (c_scale_buffer_.allocation() != nullptr) { + c_scale = allocs.GetDeviceAddress(c_scale_buffer_); + } + if (d_scale_buffer_.allocation() != nullptr) { + d_scale = allocs.GetDeviceAddress(d_scale_buffer_); + } + if (d_amax_buffer_.allocation() != nullptr) { + d_amax = allocs.GetDeviceAddress(d_amax_buffer_); + } + + se::DeviceMemoryBase aux; + if (aux_buffer_.allocation() != nullptr) { + aux = allocs.GetDeviceAddress(aux_buffer_); + } + + absl::optional workspace; + if (workspace_buffer_) { + workspace = allocs.GetDeviceAddress(*workspace_buffer_); + } + + return plan->ExecuteOnStream( + params.stream, allocs.GetDeviceAddress(a_buffer_), + allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_), + allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, + d_scale, d_amax, workspace); +} + +auto CublasLtMatmulThunk::GetCachedMatmulPlan( + const ExecuteParams& params) -> StatusOr { + + auto& cache = MatmulPlanCache::i(params.stream); + + auto create = [&]() -> StatusOr { + VLOG(2) << this << ": Adding new MatmulPlan for stream: " << params.stream << + " cfg: " << canonical_hlo_; + + TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan( + params.stream, gemm_config_)); + + int64_t max_workspace = workspace_buffer_.has_value() + ? workspace_buffer_.value().size() : 0, + algorithm_id = gemm_config_.algorithm; + int64_t num_algorithms = algorithm_id == se::blas::kDefaultAlgorithm ? + 1 : se::gpu::BlasLt::kMaxAlgorithms; + TF_ASSIGN_OR_RETURN(auto algorithms, + plan->GetAlgorithms(num_algorithms, max_workspace)); + + if (algorithm_id == se::blas::kDefaultAlgorithm && !algorithms.empty()) { + algorithm_id = algorithms[0].id; + } + for(const auto& alg : algorithms) { + if (alg.id == algorithm_id) { + TF_RETURN_IF_ERROR(plan->SetAlgorithm(alg)); + return std::move(plan); + } + } + TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[0])); + LOG(WARNING) << "Wrong algorithm ID: " << algorithm_id << " use default instead."; + return std::move(plan); + }; + return cache.GetOrCreate(canonical_hlo_, create); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.h b/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.h new file mode 100644 index 00000000000000..ea828e705b09cc --- /dev/null +++ b/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.h @@ -0,0 +1,73 @@ +/* Copyright 2022 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_ +#define TENSORFLOW_COMPILER_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace gpu { + +class CublasLtMatmulThunk : public Thunk { + public: + + CublasLtMatmulThunk( + ThunkInfo thunk_info, GemmConfig config, + BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, + BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, + BufferAllocation::Slice bias_buffer /* may be null */, + BufferAllocation::Slice aux_buffer /* may be null */, + BufferAllocation::Slice a_scale_buffer /* may be null */, + BufferAllocation::Slice b_scale_buffer /* may be null */, + BufferAllocation::Slice c_scale_buffer /* may be null */, + BufferAllocation::Slice d_scale_buffer /* may be null */, + BufferAllocation::Slice d_amax_buffer /* may be null */, + absl::optional workspace_buffer); + + Status ExecuteOnStream(const ExecuteParams& params) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + + private: + StatusOr GetCachedMatmulPlan( + const ExecuteParams& params); + + GemmConfig gemm_config_; + std::string canonical_hlo_; + BufferAllocation::Slice a_buffer_; + BufferAllocation::Slice b_buffer_; + BufferAllocation::Slice c_buffer_; + BufferAllocation::Slice d_buffer_; + BufferAllocation::Slice bias_buffer_; + BufferAllocation::Slice aux_buffer_; + BufferAllocation::Slice a_scale_buffer_; + BufferAllocation::Slice b_scale_buffer_; + BufferAllocation::Slice c_scale_buffer_; + BufferAllocation::Slice d_scale_buffer_; + BufferAllocation::Slice d_amax_buffer_; + absl::optional workspace_buffer_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_THUNK_H_ \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 12264198a0c737..a69ac468a26e58 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -53,6 +53,11 @@ bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) { return shape.rank() == batch_dimensions_size + 2; } +// Return whether the given shape is rank 1 excluding the batch dimensions. 
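// For example, with one batch dimension f32[8,128] counts as rank 1 here,
// while f32[8,128,64] does not (it is rank 2 once the batch dimension is excluded).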
+bool IsRank1(const Shape& shape, int64_t batch_dimensions_size) { + return shape.rank() == batch_dimensions_size + 1; +} + Shape GetShapeFromTensorType(mlir::Value value) { constexpr char kDefaultLayoutAttrName[] = "xla_shape"; @@ -71,6 +76,7 @@ Shape GetShapeFromTensorType(mlir::Value value) { } // namespace + bool IsMatrixMultiplication(const HloInstruction& dot) { if (dot.opcode() != HloOpcode::kDot) { return false; @@ -82,9 +88,10 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { PrimitiveType output_primitive_type = dot.shape().element_type(); bool type_is_allowed = (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || - output_primitive_type == F16 || output_primitive_type == BF16 || - output_primitive_type == F32 || output_primitive_type == F64 || - output_primitive_type == C64 || output_primitive_type == C128) || + output_primitive_type == F16 || + output_primitive_type == BF16 || output_primitive_type == F32 || + output_primitive_type == F64 || output_primitive_type == C64 || + output_primitive_type == C128) || (output_primitive_type == S32 && lhs_shape.element_type() == S8 && rhs_shape.element_type() == S8); bool shapes_are_valid = @@ -108,6 +115,36 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { return true; } +bool IsMatrixVectorMultiplication(const HloInstruction& dot) { + if (dot.opcode() != HloOpcode::kDot) { + return false; + } + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + + PrimitiveType output_primitive_type = dot.shape().element_type(); + bool type_is_allowed = + (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || + output_primitive_type == F16 || output_primitive_type == BF16 || + output_primitive_type == F32 || output_primitive_type == F64 || + output_primitive_type == C64 || output_primitive_type == C128) || + (output_primitive_type == S32 && lhs_shape.element_type() == S8 && + rhs_shape.element_type() == S8); + + bool shapes_are_valid = + type_is_allowed && + ((IsRank2(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) && + IsRank1(rhs_shape, dim_numbers.lhs_batch_dimensions_size())) || + (IsRank1(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) && + IsRank2(rhs_shape, dim_numbers.lhs_batch_dimensions_size()))) && + IsRank1(dot.shape(), dim_numbers.lhs_batch_dimensions_size()) && + !ShapeUtil::IsZeroElementArray(lhs_shape) && + !ShapeUtil::IsZeroElementArray(rhs_shape); + + return shapes_are_valid; +} + const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky"; bool IsCustomCallToCusolver(const HloInstruction& hlo) { diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index a7daa4ccaadcfd..2baeabfa8ff501 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -44,6 +44,7 @@ inline constexpr int64_t kMinTotalDimensionsToTransposeTiled = 64 * 128; // This function should never return "true" on instructions after // GemmRewriter pass has finished. 
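The matrix-vector variant declared just below accepts a dot where, once batch dimensions are stripped, one operand is rank 2, the other rank 1, and the result rank 1, with zero-element operands rejected (its definition appears in ir_emission_utils.cc above). A minimal sketch of just that rank pattern, with shapes reduced to plain ranks and an illustrative helper name:

    // Rank pattern accepted for a gemv-style dot (batch dimensions excluded):
    // (2,1) -> 1 or (1,2) -> 1; anything else is rejected.
    inline bool LooksLikeGemv(int lhs_rank, int rhs_rank, int out_rank) {
      return ((lhs_rank == 2 && rhs_rank == 1) ||
              (lhs_rank == 1 && rhs_rank == 2)) &&
             out_rank == 1;
    }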
bool IsMatrixMultiplication(const HloInstruction& dot); +bool IsMatrixVectorMultiplication(const HloInstruction& dot); inline constexpr int64_t WarpSize() { return 32; } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index f628870b99e015..48d724bdbab9c0 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -906,18 +906,19 @@ Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(bias, GetAllocationSlice(matmul.getBias())); } - BufferAllocation::Slice aux; + BufferAllocation::Slice aux, workspace; if (matmul.getAux() != nullptr) { TF_ASSIGN_OR_RETURN(aux, GetAllocationSlice(matmul.getAux())); } + if (matmul.getWorkspace() != nullptr) { + TF_ASSIGN_OR_RETURN(workspace, GetAllocationSlice(matmul.getWorkspace())); + } - TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); - TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue epilogue, - cublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); + TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(matmul)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(gemm_config), - epilogue, matmul.getAlgorithm(), a, b, c, d, bias, aux, a_scale, b_scale, - c_scale, d_scale, d_amax); + GetThunkInfo(op), std::move(config), a, b, c, d, + bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, + workspace); AddThunkToThunkSequence(std::move(thunk)); return OkStatus(); diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index dd852b44426394..ac34d289d16418 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,47 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "xla/service/gpu/matmul_utils.h" +#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include +#include #include #include +#include +#include #include -#include #include #include -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/primitive_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/numeric_options.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_blas_lt.h" -#include "xla/stream_executor/host_or_device_scalar.h" -#include "tsl/platform/tensor_float_32_utils.h" -#elif TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/status_macros.h" namespace xla { namespace gpu { +using se::blas::DataType; +using se::gpu::BlasLt; + StatusOr> GetNonContractingDims( const Shape& shape, absl::Span batch_dims, absl::Span contracting_dims) { @@ -81,32 +67,33 @@ const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( return dimension_numbers.rhs_batch_dimensions(); } -int64_t ContractingDimensionIndex(const HloInstruction& dot, - const int operand_number) { +StatusOr ContractingDimensionIndex(const HloInstruction& dot, + const int operand_number) { const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); if (operand_number == 0) { - CHECK_EQ(dimension_numbers.lhs_contracting_dimensions().size(), 1); + TF_RET_CHECK(dimension_numbers.lhs_contracting_dimensions().size() == 1); return dimension_numbers.lhs_contracting_dimensions(0); } - CHECK_EQ(dimension_numbers.rhs_contracting_dimensions().size(), 1); + TF_RET_CHECK(dimension_numbers.rhs_contracting_dimensions().size() == 1); return dimension_numbers.rhs_contracting_dimensions(0); } -int64_t NonContractingDimensionIndex(const HloInstruction& dot, - const int operand_number) { - StatusOr> non_contracting_dims = +StatusOr NonContractingDimensionIndex(const HloInstruction& dot, + const int operand_number) { + TF_ASSIGN_OR_RETURN(int64_t contracting_dim, + ContractingDimensionIndex(dot, operand_number)); + TF_ASSIGN_OR_RETURN( + std::vector non_contracting_dims, GetNonContractingDims(dot.operand(operand_number)->shape(), BatchDimensionsForOperand(dot, operand_number), - {ContractingDimensionIndex(dot, operand_number)}); - TF_CHECK_OK(non_contracting_dims.status()); - CHECK_EQ(non_contracting_dims->size(), 1); - return non_contracting_dims->front(); + {contracting_dim})); + TF_RET_CHECK(non_contracting_dims.size() == 1); + return non_contracting_dims.front(); } -StatusOr GetBatchRowColumnShape(const Shape& shape, - absl::Span batch_dims, 
- absl::Span row_dims, - absl::Span col_dims) { +StatusOr GetBatchRowColumnShape( + const Shape& shape, absl::Span batch_dims, + absl::Span row_dims, absl::Span col_dims) { TF_RET_CHECK(shape.has_layout()); std::vector minor_to_major; @@ -114,7 +101,8 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, // The GeMM output always has its layout set such that the batch, row, and // col dim groups are each laid out physically sequentially. GeMM operands // must, therefore, be laid out similarly. - auto check_physically_sequential = [&](absl::Span dims) { + auto check_physically_sequential = + [&](absl::Span dims) -> Status { for (auto it = dims.rbegin(); it != dims.rend(); ++it) { // NOTE: `i` is incremented as we check the dimensions. if (*it != shape.layout().minor_to_major()[i++]) @@ -163,7 +151,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, int64_t num_rows = shape.dimensions(1); int64_t num_cols = shape.dimensions(2); - MatrixLayout::Order order = MatrixLayout::Order::kRowMajor; + Order order{Order::kRowMajor}; int64_t leading_dim_stride = num_cols; int64_t batch_stride = num_rows * num_cols; @@ -174,7 +162,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, case 012: // (B,R,C) (major-to-minor) break; case 021: // (B,C,R) - order = MatrixLayout::Order::kColumnMajor; + order = Order::kColumnMajor; leading_dim_stride = num_rows; break; case 0102: // (R,B,C) @@ -182,7 +170,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, batch_stride = num_cols; break; case 0201: // (C,B,R) - order = MatrixLayout::Order::kColumnMajor; + order = Order::kColumnMajor; leading_dim_stride = batch_size * num_rows; batch_stride = num_rows; break; @@ -190,11 +178,13 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, return Unimplemented("batch in most minor dimension"); } - if (batch_size == 1) batch_stride = 0; - return MatrixLayout{ - shape.element_type(), num_rows, num_cols, order, - leading_dim_stride, batch_size, batch_stride, - }; + if (batch_size == 1) { + batch_stride = 0; + } + + TF_ASSIGN_OR_RETURN(auto dtype, se::gpu::AsBlasDataType(shape.element_type())); + return MatrixLayout(dtype, num_rows, num_cols, order, + batch_size, leading_dim_stride, batch_stride); } /*static*/ StatusOr MatrixLayout::For( @@ -206,11 +196,9 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, return MatrixLayout::For(batch_row_col_shape); } -/*static*/ StatusOr MatrixLayout::For(const Shape& shape, - size_t lhs_num_batch_dims, - size_t lhs_num_row_dims, - size_t rhs_num_batch_dims, - size_t rhs_num_col_dims) { +/*static*/ StatusOr MatrixLayout::For( + const Shape& shape, size_t lhs_num_batch_dims, size_t lhs_num_row_dims, + size_t rhs_num_batch_dims, size_t rhs_num_col_dims) { size_t num_batch_dims = std::max(lhs_num_batch_dims, rhs_num_batch_dims); TF_RET_CHECK(shape.rank() == @@ -227,11 +215,6 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, return MatrixLayout::For(shape, batch_dims, row_dims, col_dims); } -void MatrixLayout::Transpose() { - std::swap(num_rows, num_cols); - order = (order == Order::kRowMajor) ? Order::kColumnMajor : Order::kRowMajor; -} - namespace { // Returns the relative order of 'dims' as indices from 0 to dims.size() - 1. 
// Let 'indices' be the returned vector, then it holds that @@ -248,7 +231,10 @@ std::vector NormalizedRelativeOrder(absl::Span dims) { } // namespace StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, - int64_t operand_idx) { + int64_t operand_idx) { + // if (Cast(&dot)->sparse_operands()) { + // return false; + // } TF_RET_CHECK(dot.opcode() == HloOpcode::kDot); TF_RET_CHECK(dot.operand_count() > operand_idx); @@ -298,23 +284,9 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, - std::optional algorithm, int64_t compute_precision, - bool grad_x, bool grad_y) { - return GemmConfig::For(lhs_shape, lhs_batch_dims, lhs_contracting_dims, - rhs_shape, rhs_batch_dims, rhs_contracting_dims, - /*c_shape=*/output_shape, /*bias_shape_ptr=*/nullptr, - output_shape, alpha_real, alpha_imag, beta, algorithm, - compute_precision, grad_x, grad_y); -} + int64_t algorithm, int64_t compute_precision, + BlasLt::Epilogue epilogue) { -/*static*/ StatusOr GemmConfig::For( - const Shape& lhs_shape, absl::Span lhs_batch_dims, - absl::Span lhs_contracting_dims, const Shape& rhs_shape, - absl::Span rhs_batch_dims, - absl::Span rhs_contracting_dims, const Shape& c_shape, - const Shape* bias_shape_ptr, const Shape& output_shape, double alpha_real, - double alpha_imag, double beta, std::optional algorithm, - int64_t compute_precision, bool grad_x, bool grad_y) { absl::Span lhs_col_dims = lhs_contracting_dims; TF_ASSIGN_OR_RETURN( std::vector lhs_row_dims, @@ -352,18 +324,7 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, TF_ASSIGN_OR_RETURN(MatrixLayout output_layout, MatrixLayout::For(output_shape, output_batch_dims, output_row_dims, output_col_dims)); - Shape c_matrix_shape = c_shape; - if (primitive_util::IsF8Type(lhs_shape.element_type()) && - primitive_util::IsF8Type(output_shape.element_type()) && beta == 0.0) { - // By default, if c is not present (i.e., beta is 0), c_shape will be the - // output shape. cublasLT requires a valid c_shape to be passed, even if c - // is not present, and normally setting it to the output shape is fine. But - // for matmuls with FP8 inputs and outputs, C must instead have the same - // dtype as the vector bias if present, and either BF16 or F16 otherwise. So - // we set the dtype of C here. - c_matrix_shape.set_element_type( - bias_shape_ptr != nullptr ? bias_shape_ptr->element_type() : BF16); - } + Shape c_matrix_shape = output_shape; TF_ASSIGN_OR_RETURN(MatrixLayout c_layout, MatrixLayout::For(c_matrix_shape, output_batch_dims, @@ -373,12 +334,9 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, // non-contracting dimensions match in size and relative physical location. // TODO(philipphack): Check the remaining dimensions in the FP8 case once // cuBLASLt supports the NN configuration. 
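With the FP8 carve-out removed in this backport, the shape checks below hold unconditionally. For C[m,n] = A[m,k] * B[k,n] they amount to:

    // lhs_layout.num_cols == rhs_layout.num_rows      (the contracted k matches)
    // output_layout.num_rows == lhs_layout.num_rows   (m comes from A)
    // output_layout.num_cols == rhs_layout.num_cols   (n comes from B)

plus the requirement that c_layout matches the output layout and that a batch-size mismatch is tolerated only when the smaller side has batch_size == 1 (broadcast).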
- if (lhs_shape.element_type() != F8E4M3FN && - lhs_shape.element_type() != F8E5M2) { - TF_RET_CHECK(lhs_layout.num_cols == rhs_layout.num_rows); - TF_RET_CHECK(output_layout.num_rows == lhs_layout.num_rows); - TF_RET_CHECK(output_layout.num_cols == rhs_layout.num_cols); - } + TF_RET_CHECK(lhs_layout.num_cols == rhs_layout.num_rows); + TF_RET_CHECK(output_layout.num_rows == lhs_layout.num_rows); + TF_RET_CHECK(output_layout.num_cols == rhs_layout.num_cols); TF_RET_CHECK(c_layout.num_rows == output_layout.num_rows); TF_RET_CHECK(c_layout.num_cols == output_layout.num_cols); TF_RET_CHECK((lhs_layout.batch_size == output_layout.batch_size) || @@ -387,8 +345,6 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, (rhs_layout.batch_size == 1)); switch (output_shape.element_type()) { - case F8E4M3FN: - case F8E5M2: case F16: case BF16: case F32: @@ -400,40 +356,38 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, break; case S32: TF_RET_CHECK(alpha_imag == 0); - if (lhs_layout.dtype != PrimitiveType::S8 || - rhs_layout.dtype != PrimitiveType::S8) { - return InternalError( - "For int32 gemm output only int8 input is supported, got input: " - "%s, %s", - primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype), - primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype)); + if (lhs_layout.dtype != DataType::kInt8 || + rhs_layout.dtype != DataType::kInt8) { + return Internal( + "For int32 gemm output only int8 input is supported !"); } break; default: - return InternalError("Unexpected GEMM datatype: %s", - primitive_util::LowercasePrimitiveTypeName( - output_shape.element_type())); + return Internal("Unexpected GEMM datatype: %s", + primitive_util::LowercasePrimitiveTypeName( + output_shape.element_type())); } - return GemmConfig{ - lhs_layout, - rhs_layout, - c_layout, - output_layout, - {alpha_real, alpha_imag}, - beta, - algorithm, - compute_precision, - grad_x, - grad_y - }; + return GemmConfig( + se::gpu::GemmConfig{ + .lhs_layout = lhs_layout, + .rhs_layout = rhs_layout, + .c_layout = c_layout, + .output_layout = output_layout, + .alpha = {alpha_real, alpha_imag}, + .beta = beta, + .algorithm = algorithm, + .compute_precision = compute_precision, + .epilogue = epilogue}); } -/*static*/ StatusOr GemmConfig::For(const HloInstruction* gemm) { - TF_ASSIGN_OR_RETURN(GemmBackendConfig config, +/*static*/ StatusOr GemmConfig::For( + const HloInstruction* gemm) { + + TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); - std::optional algorithm; + int64_t algorithm = se::blas::kDefaultAlgorithm; if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) { algorithm = config.selected_algorithm(); } @@ -444,45 +398,37 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, const Shape& output_shape = gemm->shape().IsTuple() ? 
gemm->shape().tuple_shapes(0) : gemm->shape(); - auto attributes = gemm->frontend_attributes().map(); - bool grad_x = (attributes["grad_x"] == "true"); - bool grad_y = (attributes["grad_y"] == "true"); + int64_t precision = se::blas::kDefaultComputePrecision; + for (auto operand_precision : config.precision_config().operand_precision()) { + precision = std::max(precision, static_cast(operand_precision)); + } + TF_ASSIGN_OR_RETURN(auto epilogue, + gpublas_lt::AsBlasLtEpilogue(config.epilogue())); return GemmConfig::For( lhs_shape, dot_dims.lhs_batch_dimensions(), dot_dims.lhs_contracting_dimensions(), rhs_shape, - dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(), + dot_dims.rhs_batch_dimensions(), + dot_dims.rhs_contracting_dimensions(), output_shape, config.alpha_real(), config.alpha_imag(), config.beta(), - algorithm, se::blas::kDefaultComputePrecision, grad_x, grad_y); + algorithm, precision, epilogue); } /*static*/ StatusOr GemmConfig::For(mlir::lmhlo_gpu::GEMMOp op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - std::optional algorithm; + + auto dot_dims = op.getDotDimensionNumbers(); + int64_t algorithm = se::blas::kDefaultAlgorithm; if (op.getAlgorithm()) algorithm = *op.getAlgorithm(); - - bool grad_x = false; - bool grad_y = false; - auto attr_grad_x = op.getGradX(); - if (attr_grad_x) - grad_x = attr_grad_x.value(); - auto attr_grad_y = op.getGradY(); - if (attr_grad_y) - grad_y = attr_grad_y.value(); - + int64_t compute_precision = 0; // Default if (op.getPrecisionConfig().has_value()) { auto precision_config = op.getPrecisionConfig(); for (auto attr : precision_config.value()) { int64_t value = static_cast( attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } + compute_precision = std::max(value, compute_precision); } } - return GemmConfig::For( GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), @@ -490,292 +436,276 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), algorithm, compute_precision, - grad_x, grad_y); -} - -StatusOr GetBlasComputationType( - PrimitiveType lhs_dtype, PrimitiveType output_dtype, - int64_t compute_precision) { - switch (output_dtype) { - case F8E5M2: // fall-through - case F8E4M3FN: // fall-through - case F16: // fall-through - case BF16: - // Accumulate in f32 precision. - return se::blas::ComputationType::kF32; - case F32: // fall-through - case C64: -#if GOOGLE_CUDA - if (tsl::tensor_float_32_execution_enabled() && compute_precision <= 1 && - lhs_dtype == output_dtype) { - // CublasLt requires compute type to be F32 for F8 matmul. 
- // TF32 should only be chosen for FP32 or C64 gemm - return se::blas::ComputationType::kTF32AsF32; - } -#endif - return se::blas::ComputationType::kF32; - case F64: // fall-through - case C128: - return se::blas::ComputationType::kF64; - case S32: - return se::blas::ComputationType::kI32; - default: - return InternalError("GetBlasComputationType: unsupported type"); - } + BlasLt::Epilogue::kDefault); } -namespace cublas_lt { +/*static*/ StatusOr GemmConfig::For( + mlir::lmhlo_gpu::CublasLtMatmulOp op) { -se::blas::DataType GetScaleType(se::blas::DataType c_type, - se::blas::ComputationType computation_type) { - return ((computation_type == se::blas::ComputationType::kF32) && - (c_type != se::blas::DataType::kComplexFloat)) - ? se::blas::DataType::kFloat - : c_type; + auto dot_dims = op.getDotDimensionNumbers(); + int64_t algorithm = op.getAlgorithm(); + TF_ASSIGN_OR_RETURN(auto epilogue, gpublas_lt::AsBlasLtEpilogue(op.getEpilogue())); + + int64_t compute_precision = 0; // Default + if (op.getPrecisionConfig().has_value()) { + auto precision_config = op.getPrecisionConfig(); + for (auto attr : precision_config.value()) { + int64_t value = static_cast( + attr.template cast().getValue()); + compute_precision = std::max(value, compute_precision); + } + } + return GemmConfig::For( + GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), + dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), + dot_dims.getRhsBatchingDimensions(), + dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), + op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(), + op.getBeta().convertToDouble(), algorithm, compute_precision, + epilogue); } -} // namespace cublas_lt - -namespace { +StatusOr GemmConfig::GetMatrixDescriptors( + se::DeviceMemoryBase lhs_buf, se::DeviceMemoryBase rhs_buf, + se::DeviceMemoryBase out_buf) const { + auto create_matrix_desc = [](const se::gpu::MatrixLayout& layout, + se::DeviceMemoryBase data) { + return se::gpu::MatrixDescriptor{ + data, layout.leading_dim_stride, layout.batch_stride, + layout.dtype, + // BLAS is column-major by default. + (layout.order == se::gpu::MatrixLayout::Order::kColumnMajor + ? se::blas::Transpose::kNoTranspose + : se::blas::Transpose::kTranspose)}; + }; + // TODO: make a local copy to prevent modification of layouts, + // but maybe we can modify them once instead during creation ? + se::gpu::MatrixLayout lhs = lhs_layout, rhs = rhs_layout, out = output_layout; -// This struct contains the metadata of a matrix, e.g., its base address and -// dimensions. -struct MatrixDescriptor { - se::DeviceMemoryBase data; - int64_t leading_dim_stride; - int64_t batch_stride; - se::blas::Transpose transpose; - - template - se::DeviceMemory cast() const { - return se::DeviceMemory(data); - } -}; - -// BLAS GeMM's output is column-major. If we require row-major, use identity: -// C^T = (A @ B)^T = B^T @ A^T. 
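That identity is what the MakeOutputColumnMajor helpers implement; the versions removed here live on in se::gpu, and GemmConfig::GetMatrixDescriptors above relies on that copy of the same logic. A self-contained sketch of the trick, using a simplified layout type rather than the real se::gpu::MatrixLayout:

    #include <utility>

    struct Layout {
      enum class Order { kRowMajor, kColumnMajor };
      long long rows, cols;
      Order order;
      void Transpose() {
        std::swap(rows, cols);
        order = (order == Order::kRowMajor) ? Order::kColumnMajor
                                            : Order::kRowMajor;
      }
    };

    // If the requested output is row-major, solve the transposed problem:
    // C^T = (A * B)^T = B^T * A^T, and a row-major C is bit-identical to a
    // column-major C^T.
    bool MakeOutputColumnMajorSketch(Layout& lhs, Layout& rhs, Layout& out) {
      bool swap_operands = (out.order != Layout::Order::kColumnMajor);
      if (swap_operands) {
        std::swap(lhs, rhs);   // B plays the role of A and vice versa...
        lhs.Transpose();       // ...and every layout flips its rows/cols and order.
        rhs.Transpose();
        out.Transpose();
      }
      return swap_operands;    // caller must also swap the A and B device buffers
    }

The boolean result matters: whoever swaps the layouts must also swap the A and B device buffers, which is exactly what GetMatrixDescriptors does before building the descriptors.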
-bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, - MatrixLayout& output) { - bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor; - if (swap_operands) { - std::swap(lhs, rhs); - lhs.Transpose(); - rhs.Transpose(); - output.Transpose(); + bool must_swap_operands = MakeOutputColumnMajor(lhs, rhs, out); + if (must_swap_operands) { + std::swap(lhs_buf, rhs_buf); } - return swap_operands; -} -bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, - MatrixLayout& output, MatrixLayout& c) { - bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor; - if (swap_operands) { - std::swap(lhs, rhs); - rhs.Transpose(); - lhs.Transpose(); - c.Transpose(); - output.Transpose(); - } - return swap_operands; -} + se::gpu::OutputMatrixDescriptor out_desc = create_matrix_desc(out, out_buf); + out_desc.batch_size = out.batch_size; + out_desc.m = out.num_rows; + out_desc.n = out.num_cols; + out_desc.k = lhs.num_cols; + // TODO(tdanyluk): Investigate why don't we use the actual precision (and + // algorithm) here? Why do we use the default? + TF_ASSIGN_OR_RETURN(out_desc.compute_type, + se::gpu::GetBlasComputationType( + lhs.dtype, out.dtype, -1)); -se::blas::Transpose AsBlasTranspose(MatrixLayout::Order order) { - // BLAS is column-major by default. - return (order == MatrixLayout::Order::kColumnMajor) - ? se::blas::Transpose::kNoTranspose - : se::blas::Transpose::kTranspose; -} + se::gpu::MatrixDescriptor lhs_desc = create_matrix_desc(lhs, lhs_buf), + rhs_desc = create_matrix_desc(rhs, rhs_buf); -MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout, - se::DeviceMemoryBase data) { - return { - data, - layout.leading_dim_stride, - layout.batch_stride, - AsBlasTranspose(layout.order), - }; + return DescriptorsTuple{lhs_desc, rhs_desc, out_desc, must_swap_operands}; } +namespace { + template -Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k, - const MatrixDescriptor& lhs, - const MatrixDescriptor& rhs, - const MatrixDescriptor& output, Scale alpha, - Scale beta, se::Stream* stream, - se::blas::AlgorithmType algorithm, - se::blas::ComputePrecision compute_precision, - const se::NumericOptions& numeric_options, - se::blas::ProfileResult* profile_result, - se::blas::CallContext context) { +Status DoGemmWithAlgorithm(const se::gpu::MatrixDescriptor& lhs, + const se::gpu::MatrixDescriptor& rhs, + const se::gpu::OutputMatrixDescriptor& output, + se::DeviceMemoryBase workspace, Scale alpha, + Scale beta, se::Stream* stream, + se::blas::AlgorithmType algorithm, + se::blas::ComputePrecision compute_precision, + const se::NumericOptions& numeric_options, + se::blas::ProfileResult* profile_result, + se::blas::CallContext context) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); - PrimitiveType lhs_type = primitive_util::NativeToPrimitiveType(); - PrimitiveType output_type = primitive_util::NativeToPrimitiveType(); TF_ASSIGN_OR_RETURN( se::blas::ComputationType computation_type, - GetBlasComputationType(lhs_type, output_type, compute_precision)); + se::gpu::GetBlasComputationType(lhs.type, + output.type, compute_precision)); se::DeviceMemory output_data(output.data); - if (batch_size != 1) { + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } + // Set a workspace for all Blas operations launched below. 
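// (Assumption: ScopedWorkspace acts as an RAII guard that points this BlasSupport
// at the thunk-provided workspace for the calls below and restores the previous
// workspace when it goes out of scope.)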
+ se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); + + if (output.batch_size != 1) { return stream->ThenBlasGemmStridedBatchedWithAlgorithm( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), - rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, - output.leading_dim_stride, output.batch_stride, batch_size, - computation_type, algorithm, numeric_options, profile_result, - context); + lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, + rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, + &output_data, output.leading_dim_stride, output.batch_stride, + output.batch_size, computation_type, algorithm, compute_precision, + profile_result, context); } else { return stream->ThenBlasGemmWithAlgorithm( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, - &output_data, output.leading_dim_stride, computation_type, algorithm, - numeric_options, profile_result, context); + lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, rhs.cast(), + rhs.leading_dim_stride, beta, &output_data, output.leading_dim_stride, + computation_type, algorithm, compute_precision, profile_result, context); } } template -Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, - const MatrixDescriptor& lhs, const MatrixDescriptor& rhs, - const MatrixDescriptor& output, Scale alpha, Scale beta, - se::Stream* stream, - std::optional algorithm, - se::blas::ComputePrecision compute_precision, - const se::NumericOptions& numeric_options, - se::blas::ProfileResult* profile_result, - se::blas::CallContext context) { +Status DoGemm(const se::gpu::MatrixDescriptor& lhs, + const se::gpu::MatrixDescriptor& rhs, + const se::gpu::OutputMatrixDescriptor& output, + se::DeviceMemoryBase workspace, Scale alpha, Scale beta, + se::Stream* stream, + std::optional algorithm, + se::blas::ComputePrecision compute_precision, + const se::NumericOptions& numeric_options, + se::blas::ProfileResult* profile_result, + se::blas::CallContext context) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); se::DeviceMemory output_data(output.data); -#if GOOGLE_CUDA if (algorithm) { return DoGemmWithAlgorithm( - batch_size, m, n, k, lhs, rhs, output, alpha, beta, stream, *algorithm, - compute_precision, numeric_options, profile_result, context); + lhs, rhs, output, workspace, alpha, beta, stream, + *algorithm, compute_precision, numeric_options, profile_result, + context); } -#endif - if (batch_size != 1) { + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } + // Set a workspace for all Blas operations launched below. 
+ se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); + + if (output.batch_size != 1) { return stream->ThenBlasGemmStridedBatched( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), - rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, - output.leading_dim_stride, output.batch_stride, batch_size, - numeric_options, context); + lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, + rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, + &output_data, output.leading_dim_stride, output.batch_stride, + output.batch_size, compute_precision, context); } - return stream->ThenBlasGemm( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, - &output_data, output.leading_dim_stride, numeric_options, - context); + return stream->ThenBlasGemm(lhs.transpose, rhs.transpose, output.m, + output.n, output.k, alpha, lhs.cast(), + lhs.leading_dim_stride, rhs.cast(), + rhs.leading_dim_stride, beta, &output_data, + output.leading_dim_stride, compute_precision, context); } +#define DT(E, T) \ + template <> struct DataTypeToNative { using Type = T; } + +template < DataType E > +struct DataTypeToNative {}; + +DT(kFloat, float); +DT(kDouble, double); +DT(kHalf, Eigen::half); +DT(kInt8, int8_t); +DT(kInt32, int32_t); +DT(kComplexFloat, complex64); +DT(kComplexDouble, complex128); +DT(kBF16, Eigen::bfloat16); + +template < DataType E > +using DataType_t = typename DataTypeToNative< E >::Type; + +#undef DT + } // namespace Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, bool deterministic_ops, - se::Stream* stream, - std::optional algorithm, - se::blas::ProfileResult* profile_result) { + se::DeviceMemoryBase rhs_buffer, + se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase workspace_buffer, + bool deterministic_ops, se::Stream* stream, + std::optional algorithm, + se::blas::ProfileResult* profile_result) { VLOG(2) << "Executing a GemmThunk"; - MatrixLayout lhs_layout = config.lhs_layout; - MatrixLayout rhs_layout = config.rhs_layout; - MatrixLayout output_layout = config.output_layout; - bool must_swap_operands = - MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout); - if (must_swap_operands) { - std::swap(lhs_buffer, rhs_buffer); - } + TF_ASSIGN_OR_RETURN( + GemmConfig::DescriptorsTuple desc, + config.GetMatrixDescriptors(lhs_buffer, rhs_buffer, output_buffer)); - int64_t m = output_layout.num_rows; - int64_t n = output_layout.num_cols; - int64_t k = lhs_layout.num_cols; - MatrixDescriptor lhs = GetMatrixDesc(lhs_layout, lhs_buffer); - MatrixDescriptor rhs = GetMatrixDesc(rhs_layout, rhs_buffer); - MatrixDescriptor output = GetMatrixDesc(output_layout, output_buffer); - int64_t batch_size = output_layout.batch_size; - se::NumericOptions numeric_options{ - deterministic_ops, - /*allow_tf32=*/config.compute_precision <= 1}; + se::NumericOptions numeric_options(deterministic_ops); if (!algorithm) algorithm = config.algorithm; - se::blas::CallContext context = se::blas::CallContext::kNone; - if (config.grad_x) { - context = must_swap_operands - ? se::blas::CallContext::kBackpropInput2 - : se::blas::CallContext::kBackpropInput1; - } - if (config.grad_y) { - context = must_swap_operands - ? 
se::blas::CallContext::kBackpropInput1 - : se::blas::CallContext::kBackpropInput2; - } + se::blas::CallContext context = se::blas::CallContext::kNone; + std::tuple operand_types{config.lhs_layout.dtype, config.rhs_layout.dtype, + config.output_layout.dtype}; - std::tuple operand_types{ - lhs_layout.dtype, rhs_layout.dtype, output_layout.dtype}; + // Skip degenerate gemm with memzero. In general this is not safe, because it + // will suppress NaN propagation, however cuBLAS internally has exactly the + // same optimization for compatibility with NETLIB implementation, so we are + // not making things worse (and cuBLAS optimization is incompatible with CUDA + // graphs, so we are making sure we do not trigger it). + if (config.alpha.real() == 0.0 && config.alpha.imag() == 0.0 && + config.beta == 0.0) { + stream->ThenMemZero(&output_buffer, output_buffer.size()); + return OkStatus(); + } #define TYPED_GEMM(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ - if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) { \ - using NativeScaleType = \ - primitive_util::PrimitiveTypeToNative::type; \ - using NativeAType = primitive_util::PrimitiveTypeToNative::type; \ - using NativeCType = primitive_util::PrimitiveTypeToNative::type; \ + if (operand_types == std::tuple{DataType::ATYPE, DataType::BTYPE, DataType::CTYPE}) { \ + using NativeScaleType = DataType_t; \ + using NativeAType = DataType_t; \ + using NativeCType = DataType_t; \ return DoGemm( \ - batch_size, m, n, k, lhs, rhs, output, \ + desc.lhs, desc.rhs, desc.output, workspace_buffer, \ static_cast(config.alpha.real()), \ - static_cast(config.beta), stream, algorithm, \ - config.compute_precision, numeric_options, profile_result, \ - context); \ + static_cast(config.beta), stream, \ + algorithm, config.compute_precision, \ + numeric_options, profile_result, context); \ } #define TYPED_GEMM_COMPLEX(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ - if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) { \ - using NativeScaleType = \ - primitive_util::PrimitiveTypeToNative::type; \ - using NativeAType = primitive_util::PrimitiveTypeToNative::type; \ - using NativeCType = primitive_util::PrimitiveTypeToNative::type; \ + if (operand_types == std::tuple(DataType::ATYPE, DataType::BTYPE, DataType::CTYPE)) { \ + using NativeScaleType = DataType_t; \ + using NativeAType = DataType_t; \ + using NativeCType = DataType_t; \ return DoGemm( \ - batch_size, m, n, k, lhs, rhs, output, \ + desc.lhs, desc.rhs, desc.output, workspace_buffer, \ static_cast(config.alpha), \ - static_cast(config.beta), stream, algorithm, \ - config.compute_precision, numeric_options, profile_result, \ - context); \ - } - - if (output_layout.dtype == S32) { - if (!algorithm) algorithm = se::blas::kDefaultGemmAlgo; - return DoGemmWithAlgorithm( - batch_size, m, n, k, lhs, rhs, output, - static_cast(config.alpha.real()), - static_cast(config.beta), stream, *algorithm, - se::blas::kDefaultComputePrecision, numeric_options, profile_result, - context); - } - - TYPED_GEMM(F32, BF16, BF16, BF16) - TYPED_GEMM(F32, F16, F16, F16) - TYPED_GEMM(F32, S8, S8, F32) - TYPED_GEMM(F32, BF16, BF16, F32) - TYPED_GEMM(F32, F16, F16, F32) - TYPED_GEMM(F32, F32, F32, F32) - TYPED_GEMM(F64, F64, F64, F64) - TYPED_GEMM_COMPLEX(C64, C64, C64, C64) - TYPED_GEMM_COMPLEX(C128, C128, C128, C128) + static_cast(config.beta), stream, \ + algorithm, config.compute_precision, \ + numeric_options, profile_result, context); \ + } + + // if (config.output_layout.dtype == DataType::kInt32) { + // if (!algorithm) algorithm = 
se::blas::kDefaultGemmAlgo; + // // TODO(tdanyluk): Investigate why don't we use the actual precision (and + // // algorithm) here? Why do we use the default? + // return DoGemmWithAlgorithm( + // desc.lhs, desc.rhs, desc.output, workspace_buffer, + // static_cast(config.alpha.real()), + // static_cast(config.beta), stream, + // *algorithm, se::blas::kDefaultComputePrecision, numeric_options, + // profile_result, context); + // } + + TYPED_GEMM(kFloat, kBF16, kBF16, kBF16) + TYPED_GEMM(kFloat, kHalf, kHalf, kHalf) + // TYPED_GEMM(kFloat, kInt8, kInt8, kFloat) + // TYPED_GEMM(kFloat, kBF16, kBF16, kFloat) + // TYPED_GEMM(kFloat, kHalf, kHalf, kFloat) + TYPED_GEMM(kFloat, kFloat, kFloat, kFloat) + TYPED_GEMM(kDouble, kDouble, kDouble, kDouble) + TYPED_GEMM_COMPLEX(kComplexFloat, kComplexFloat, kComplexFloat, kComplexFloat) + TYPED_GEMM_COMPLEX(kComplexDouble, kComplexDouble, kComplexDouble, kComplexDouble) #undef TYPED_GEMM #undef TYPED_GEMM_COMPLEX - return InternalError( - "Unexpected GEMM dtype: %s %s %s", - primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype), - primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype), - primitive_util::LowercasePrimitiveTypeName(output_layout.dtype)); -} + return Internal( + "Unexpected GEMM dtype: %d %d %d", + (int)(config.lhs_layout.dtype), (int)(config.rhs_layout.dtype), + (int)(config.output_layout.dtype)); +} // namespace gpu -namespace cublas_lt { +namespace gpublas_lt { -StatusOr EpilogueAddsVectorBias(GemmBackendConfig_Epilogue epilogue) { +StatusOr EpilogueAddsVectorBias( + GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { case GemmBackendConfig::DEFAULT: case GemmBackendConfig::RELU: @@ -788,11 +718,12 @@ StatusOr EpilogueAddsVectorBias(GemmBackendConfig_Epilogue epilogue) { case GemmBackendConfig::BIAS_GELU_AUX: return true; default: - return InternalError("Unknown Epilogue."); + return Internal("Unknown BlasLt::Epilogue."); } } -StatusOr EpilogueHasAuxiliaryOutput(GemmBackendConfig_Epilogue epilogue) { +StatusOr EpilogueHasAuxiliaryOutput( + GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { case GemmBackendConfig::DEFAULT: case GemmBackendConfig::RELU: @@ -805,336 +736,111 @@ StatusOr EpilogueHasAuxiliaryOutput(GemmBackendConfig_Epilogue epilogue) { case GemmBackendConfig::BIAS_GELU_AUX: return true; default: - return InternalError("Unknown Epilogue."); + return Internal("Unknown BlasLt::Epilogue."); } } -} // namespace cublas_lt - -StatusOr AsBlasDataType(PrimitiveType dtype) { - switch (dtype) { - case F8E5M2: - return se::blas::DataType::kF8E5M2; - case F8E4M3FN: - return se::blas::DataType::kF8E4M3FN; - case S8: - return se::blas::DataType::kInt8; - case F16: - return se::blas::DataType::kHalf; - case BF16: - return se::blas::DataType::kBF16; - case F32: - return se::blas::DataType::kFloat; - case S32: - return se::blas::DataType::kInt32; - case F64: - return se::blas::DataType::kDouble; - case C64: - return se::blas::DataType::kComplexFloat; - case C128: - return se::blas::DataType::kComplexDouble; +StatusOr AsBlasLtEpilogue( + GemmBackendConfig_Epilogue epilogue) { + switch (epilogue) { + case GemmBackendConfig::DEFAULT: + return BlasLt::Epilogue::kDefault; + case GemmBackendConfig::RELU: + return BlasLt::Epilogue::kReLU; + case GemmBackendConfig::GELU: + return BlasLt::Epilogue::kGELU; + case GemmBackendConfig::GELU_AUX: + return BlasLt::Epilogue::kGELUWithAux; + case GemmBackendConfig::BIAS: + return BlasLt::Epilogue::kBias; + case GemmBackendConfig::BIAS_RELU: + return BlasLt::Epilogue::kBiasThenReLU; + 
case GemmBackendConfig::BIAS_GELU: + return BlasLt::Epilogue::kBiasThenGELU; + case GemmBackendConfig::BIAS_GELU_AUX: + return BlasLt::Epilogue::kBiasThenGELUWithAux; default: - return InternalError("AsBlasDataType: unsupported type: %s", - primitive_util::LowercasePrimitiveTypeName(dtype)); + return Internal("unexpected epilogue value"); } } -#if GOOGLE_CUDA || TF_HIPBLASLT - -namespace { - -StatusOr AsBlasLtMatrixLayout( - const MatrixLayout& layout) { - TF_ASSIGN_OR_RETURN(se::blas::DataType dtype, AsBlasDataType(layout.dtype)); - - auto order = (layout.order == MatrixLayout::Order::kColumnMajor) - ? se::gpu::BlasLt::MatrixLayout::Order::kColumnMajor - : se::gpu::BlasLt::MatrixLayout::Order::kRowMajor; - - return se::gpu::BlasLt::MatrixLayout::Create( - dtype, layout.num_rows, layout.num_cols, order, layout.batch_size, - layout.leading_dim_stride, layout.batch_stride); -} - -#if TF_HIPBLASLT -using cudaDataType_t = hipDataType; -#define CUDA_R_16BF HIP_R_16BF -#define CUDA_R_16F HIP_R_16F -#define CUDA_R_32F HIP_R_32F -#define CUDA_R_64F HIP_R_64F -#define CUDA_C_32F HIP_C_32F -#define CUDA_C_64F HIP_C_64F -#endif - -template -struct CudaToNativeT; - -#if CUDA_VERSION >= 11080 -template <> -struct CudaToNativeT { - using type = tsl::float8_e4m3fn; -}; -template <> -struct CudaToNativeT { - using type = tsl::float8_e5m2; -}; -#endif - -template <> -struct CudaToNativeT { - using type = Eigen::bfloat16; -}; -template <> -struct CudaToNativeT { - using type = Eigen::half; -}; -template <> -struct CudaToNativeT { - using type = float; -}; -template <> -struct CudaToNativeT { - using type = double; -}; -template <> -struct CudaToNativeT { - using type = complex64; -}; -template <> -struct CudaToNativeT { - using type = complex128; -}; - -} // namespace - -namespace cublas_lt { - -StatusOr AsBlasLtEpilogue( +StatusOr AsBlasLtEpilogue( mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue) { switch (epilogue) { case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Default: - return se::gpu::BlasLt::Epilogue::kDefault; + return BlasLt::Epilogue::kDefault; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Relu: - return se::gpu::BlasLt::Epilogue::kReLU; + return BlasLt::Epilogue::kReLU; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Gelu: - return se::gpu::BlasLt::Epilogue::kGELU; + return BlasLt::Epilogue::kGELU; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::GeluAux: - return se::gpu::BlasLt::Epilogue::kGELUWithAux; + return BlasLt::Epilogue::kGELUWithAux; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Bias: - return se::gpu::BlasLt::Epilogue::kBias; + return BlasLt::Epilogue::kBias; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasRelu: - return se::gpu::BlasLt::Epilogue::kBiasThenReLU; + return BlasLt::Epilogue::kBiasThenReLU; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGelu: - return se::gpu::BlasLt::Epilogue::kBiasThenGELU; + return BlasLt::Epilogue::kBiasThenGELU; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGeluAux: - return se::gpu::BlasLt::Epilogue::kBiasThenGELUWithAux; + return BlasLt::Epilogue::kBiasThenGELUWithAux; } return InternalError("unexpected epilogue value"); } -/*static*/ StatusOr MatmulPlan::From( - const GemmConfig& config, se::gpu::BlasLt::Epilogue epilogue) { - MatrixLayout lhs_layout = config.lhs_layout; - MatrixLayout rhs_layout = config.rhs_layout; - MatrixLayout output_layout = config.output_layout; - MatrixLayout c_layout = config.c_layout; - - // cublasLt matmul requires batch sizes to be equal. 
If only one operand has a - // batch, the other will be broadcast (as its batch_stride == 0). - size_t batch_size = std::max(lhs_layout.batch_size, rhs_layout.batch_size); - lhs_layout.batch_size = batch_size; - rhs_layout.batch_size = batch_size; - - bool must_swap_operands = - MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout, c_layout); - - // Do not transopse either input. Note the cuBLASLt documentation somewhat - // incorrectly claims "A must be transposed and B non-transposed" when A and B - // are FP8 (https://docs.nvidia.com/cuda/cublas/#cublasltmatmul). In reality, - // this is only true if A and B are column-major. If A is row-major, A must - // *not* be transposed, and if B is row-major, B must be transposed. We never - // transpose A or B, and expect the caller to ensure A is row-major and B is - // column when A and B are FP8. - se::blas::Transpose trans_a = se::blas::Transpose::kNoTranspose; - se::blas::Transpose trans_b = se::blas::Transpose::kNoTranspose; - if (primitive_util::IsF8Type(lhs_layout.dtype) && - lhs_layout.order == MatrixLayout::Order::kColumnMajor) { - return InternalError("The F8 LHS must be column-major"); - } - if (primitive_util::IsF8Type(rhs_layout.dtype) && - rhs_layout.order == MatrixLayout::Order::kRowMajor) { - return InternalError("The F8 RHS must be row-major"); - } +} // namespace gpublas_lt - TF_ASSIGN_OR_RETURN(se::blas::DataType output_dtype, - AsBlasDataType(output_layout.dtype)); - TF_ASSIGN_OR_RETURN( - se::blas::ComputationType computation_type, - GetBlasComputationType(lhs_layout.dtype, output_layout.dtype, - config.compute_precision)); +StatusOr IsMatrixMultiplicationTooSmallForRewriting( + const HloInstruction& dot, int64_t threshold) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); -#if TENSORFLOW_USE_ROCM - if (lhs_layout.order == MatrixLayout::Order::kRowMajor) { - trans_a = se::blas::Transpose::kTranspose; - lhs_layout.Transpose(); - } - if (rhs_layout.order == MatrixLayout::Order::kRowMajor) { - trans_b = se::blas::Transpose::kTranspose; - rhs_layout.Transpose(); + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dot_dims = dot.dot_dimension_numbers(); + + int64_t contracting_size = 1; + for (int64_t dim : dot_dims.lhs_contracting_dimensions()) { + contracting_size *= lhs_shape.dimensions(dim); } -#endif TF_ASSIGN_OR_RETURN( - se::gpu::BlasLt::MatmulDesc op_desc, - se::gpu::BlasLt::MatmulDesc::Create( - computation_type, GetScaleType(output_dtype, computation_type), - trans_a, trans_b, epilogue)); - - TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::MatrixLayout a_desc, - AsBlasLtMatrixLayout(lhs_layout)); - TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::MatrixLayout b_desc, - AsBlasLtMatrixLayout(rhs_layout)); - TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::MatrixLayout c_desc, - AsBlasLtMatrixLayout(c_layout)); - TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::MatrixLayout d_desc, - AsBlasLtMatrixLayout(output_layout)); - - return MatmulPlan{ - se::gpu::BlasLt::MatmulPlan{std::move(op_desc), std::move(a_desc), - std::move(b_desc), std::move(c_desc), - std::move(d_desc)}, - config.alpha, config.beta, must_swap_operands}; -} + std::vector lhs_non_contracting_dims, + GetNonContractingDims(lhs_shape, dot_dims.lhs_batch_dimensions(), + dot_dims.lhs_contracting_dimensions())); + int64_t lhs_non_contracting_size = 1; + for (int64_t dim : lhs_non_contracting_dims) { + lhs_non_contracting_size *= lhs_shape.dimensions(dim); + } -template -Status MatmulPlan::DoMatmul( - se::Stream* stream, 
se::DeviceMemoryBase a_buffer, - se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer, - se::DeviceMemoryBase d_buffer, se::DeviceMemoryBase bias_buffer, - se::DeviceMemoryBase aux_buffer, se::DeviceMemoryBase a_scale_buffer, - se::DeviceMemoryBase b_scale_buffer, se::DeviceMemoryBase c_scale_buffer, - se::DeviceMemoryBase d_scale_buffer, se::DeviceMemoryBase d_amax_buffer, - const se::gpu::BlasLt::MatmulAlgorithm& algorithm, - se::ScratchAllocator& scratch_allocator, - se::blas::ProfileResult* profile_result) const { - se::gpu::BlasLt* blas_lt = se::gpu::GetBlasLt(stream); - TF_RET_CHECK(blas_lt != nullptr); - - Scale alpha; - if constexpr (std::is_same_v || - std::is_same_v) { - alpha = static_cast(alpha_); - } else { - alpha = static_cast(alpha_.real()); + TF_ASSIGN_OR_RETURN( + std::vector rhs_non_contracting_dims, + GetNonContractingDims(rhs_shape, dot_dims.rhs_batch_dimensions(), + dot_dims.rhs_contracting_dimensions())); + int64_t rhs_non_contracting_size = 1; + for (int64_t dim : rhs_non_contracting_dims) { + rhs_non_contracting_size *= rhs_shape.dimensions(dim); } - Scale beta = static_cast(beta_); - - se::DeviceMemory output(d_buffer); - return blas_lt->DoMatmul( - stream, plan_, se::HostOrDeviceScalar(alpha), - se::DeviceMemory(a_buffer), se::DeviceMemory(b_buffer), - se::HostOrDeviceScalar(beta), se::DeviceMemory(c_buffer), - output, algorithm, scratch_allocator, se::DeviceMemory(bias_buffer), - aux_buffer, se::DeviceMemory(a_scale_buffer), - se::DeviceMemory(b_scale_buffer), - se::DeviceMemory(c_scale_buffer), - se::DeviceMemory(d_scale_buffer), - se::DeviceMemory(d_amax_buffer), profile_result); + return (rhs_non_contracting_size + lhs_non_contracting_size) * + contracting_size < + threshold; } -Status MatmulPlan::ExecuteOnStream( - se::Stream* stream, se::DeviceMemoryBase a_buffer, - se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer, - se::DeviceMemoryBase d_buffer, se::DeviceMemoryBase bias_buffer, - se::DeviceMemoryBase aux_buffer, se::DeviceMemoryBase a_scale_buffer, - se::DeviceMemoryBase b_scale_buffer, se::DeviceMemoryBase c_scale_buffer, - se::DeviceMemoryBase d_scale_buffer, se::DeviceMemoryBase d_amax_buffer, - const se::gpu::BlasLt::MatmulAlgorithm& algorithm, - se::ScratchAllocator& scratch_allocator, - se::blas::ProfileResult* profile_result) const { - if (must_swap_operands_) { - std::swap(a_buffer, b_buffer); - } +bool IsDotSupportedByClassicalEmitters(const HloInstruction& dot) { + // if (!algorithm_util::IsSupportedByElementalIrEmitter( + // dot.precision_config().algorithm())) { + // return false; + // } - std::tuple - operand_types{plan_.a_desc.type(), plan_.b_desc.type(), - plan_.c_desc.type(), plan_.d_desc.type()}; - -#define TYPED_MATMUL(SCALENTYPE, ATYPE, BTYPE, CTYPE, DTYPE) \ - if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE, DTYPE)) { \ - return DoMatmul::type, \ - CudaToNativeT::type, CudaToNativeT::type, \ - CudaToNativeT::type>( \ - stream, a_buffer, b_buffer, c_buffer, d_buffer, bias_buffer, \ - aux_buffer, a_scale_buffer, b_scale_buffer, c_scale_buffer, \ - d_scale_buffer, d_amax_buffer, algorithm, scratch_allocator, \ - profile_result); \ + // Let us be conservative and only throw float dots at the emitters. 
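+  // For example, an s8 x s8 -> s32 dot hits the default case below and is +  // routed to cuBLAS or Triton rather than the classical emitters.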
+ switch (dot.shape().element_type()) { + case F16: + case F32: + case BF16: + return true; + default: + return false; } - -#if CUDA_VERSION >= 11080 - // FP8 compatible type combinations (see cuBLASLt documentation): - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16BF, CUDA_R_16BF) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16BF, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16F, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16F, CUDA_R_16F) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_32F, CUDA_R_32F) - - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16BF, CUDA_R_16BF) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16BF, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16BF, - CUDA_R_8F_E5M2) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16F, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16F, - CUDA_R_8F_E5M2) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16F, CUDA_R_16F) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_32F, CUDA_R_32F) - - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16BF, CUDA_R_16BF) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16BF, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16BF, - CUDA_R_8F_E5M2) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16F, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16F, - CUDA_R_8F_E5M2) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16F, CUDA_R_16F) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_32F, CUDA_R_32F) -#endif - - // Other data types: - TYPED_MATMUL(float, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF) - TYPED_MATMUL(float, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) - TYPED_MATMUL(float, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_32F, CUDA_R_32F) - TYPED_MATMUL(float, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F, CUDA_R_32F) - TYPED_MATMUL(float, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F) - TYPED_MATMUL(double, CUDA_R_64F, CUDA_R_64F, CUDA_R_64F, CUDA_R_64F) - TYPED_MATMUL(complex64, CUDA_C_32F, CUDA_C_32F, CUDA_C_32F, CUDA_C_32F) - TYPED_MATMUL(complex128, CUDA_C_64F, CUDA_C_64F, CUDA_C_64F, CUDA_C_64F) - -#undef TYPED_MATMUL - - return InternalError("Unexpected dtype"); } -StatusOr> -MatmulPlan::GetAlgorithms(se::Stream* stream) const { - se::gpu::BlasLt* blas_lt = se::gpu::GetBlasLt(stream); - TF_RET_CHECK(blas_lt != nullptr); - TF_ASSIGN_OR_RETURN(auto preference, - se::gpu::BlasLt::MatmulPreference::Create( - /*max_workspace_size=*/1ll << 32)); // 4GB - return blas_lt->GetMatmulAlgorithms(plan_, preference); -} - -} // namespace cublas_lt - -#endif // GOOGLE_CUDA || TF_HIPBLASLT - } // namespace gpu -} // namespace xla +} // namespace xla \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h index 294123b48f1309..36e2029894ef29 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.h +++ b/third_party/xla/xla/service/gpu/matmul_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -13,42 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_MATMUL_UTILS_H_ -#define XLA_SERVICE_GPU_MATMUL_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_ +#include #include #include +#include +#include #include #include #include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/shape.h" -#include "xla/statusor.h" -#include "xla/stream_executor/blas.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_blas_lt.h" -#include "xla/stream_executor/scratch_allocator.h" - -#elif TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#if TF_HIPBLASLT -#include "xla/stream_executor/rocm/hip_blas_lt.h" -#include "xla/stream_executor/scratch_allocator.h" -#endif // TF_HIPBLASLT -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/stream_executor/blas.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { +// Ordered non-contracting dimensions for a dot instruction operand. StatusOr> GetNonContractingDims( const Shape& shape, absl::Span batch_dims, absl::Span contracting_dims); @@ -59,60 +50,74 @@ const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( const HloInstruction& dot, int operand_number); // Index of the only contracting dimension of dot instruction operand. -int64_t ContractingDimensionIndex(const HloInstruction& dot, - int operand_number); +StatusOr ContractingDimensionIndex(const HloInstruction& dot, + int operand_number); // Index of the only non-contracting dimension of dot instruction operand. -int64_t NonContractingDimensionIndex(const HloInstruction& dot, - int operand_number); +StatusOr NonContractingDimensionIndex(const HloInstruction& dot, + int operand_number); // Normalize shape to (batch, rows, columns) logical dimensions. -StatusOr GetBatchRowColumnShape(const Shape& shape, - absl::Span batch_dims, - absl::Span row_dims, - absl::Span col_dims); +StatusOr GetBatchRowColumnShape( + const Shape& shape, absl::Span batch_dims, + absl::Span row_dims, absl::Span col_dims); -struct MatrixLayout { - enum class Order { - kRowMajor, // Elements in the same row are contiguous in memory. - kColumnMajor, // Elements in the same column are contiguous in memory. - }; +// GPU folding rule for the `TransposeFolding` pass. +StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, + int64_t operand_idx); + +// Returns true if the sum of the sizes of the unbatched operand matrices +// for the dot is smaller than the given threshold. 
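+// For example (hypothetical shapes): a dot of f32[16,256] with f32[256,8] has +// contracting size 256 and non-contracting sizes 16 and 8, so the heuristic +// value is (16 + 8) * 256 = 6144, well below a threshold of, say, 100000, and +// the dot is reported as too small for rewriting.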
+StatusOr IsMatrixMultiplicationTooSmallForRewriting( + const HloInstruction& dot, int64_t threshold); + +// Returns true if the backend can lower the dot. Currently the classical +// emitters cannot handle some dots, e.g., i8[] x i8[] -> i32[] dots, +// so we need to always use cuBLAS or Triton for those. +bool IsDotSupportedByClassicalEmitters(const HloInstruction& dot); + +// extending plain MatrixLayout struct with creator functions +struct MatrixLayout : public se::gpu::MatrixLayout { + + MatrixLayout(se::blas::DataType dtype_, int64_t num_rows_, int64_t num_cols_, + Order order_, int64_t batch_size_ = 1, + absl::optional leading_dim_stride_ = {}, + absl::optional batch_stride_ = {}, + absl::optional transpose_ = {}) : + se::gpu::MatrixLayout(dtype_, num_rows_, num_cols_, + order_, batch_size_, leading_dim_stride_, + batch_stride_, transpose_) {} // Returns the matrix layout for a logical shape (batch, rows, columns). static StatusOr For(const Shape& shape); // Returns the matrix layout with the given batch, row, col dimensions. static StatusOr For(const Shape& shape, - absl::Span batch_dims, - absl::Span row_dims, - absl::Span col_dims); + absl::Span batch_dims, + absl::Span row_dims, + absl::Span col_dims); // Returns the matrix layout for the output. static StatusOr For(const Shape& shape, - size_t lhs_num_batch_dims, - size_t lhs_num_row_dims, - size_t rhs_num_batch_dims, - size_t rhs_num_col_dims); - - void Transpose(); - - PrimitiveType dtype; - // `num_rows` / `num_cols` are for the "logical" matrix shape: - // i.e. the contracting dim has size `num_cols` for LHS operands and - // `num_rows` for RHS operands. - int64_t num_rows; - int64_t num_cols; - Order order; - int64_t leading_dim_stride; - int64_t batch_size; - int64_t batch_stride; // `batch_stride` is set to `0` when `batch_size == 1`. + size_t lhs_num_batch_dims, + size_t lhs_num_row_dims, + size_t rhs_num_batch_dims, + size_t rhs_num_col_dims); }; -// GPU folding rule for the `TransposeFolding` pass. -StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, - int64_t operand_idx); +struct GemmConfig : public se::gpu::GemmConfig { + // For legacy Gemm operations XLA:GPU allocates its own workspace and passes + // it to all BLAS API calls. + // + // Size of the workspace based on NVIDIA recommendation: + // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace + static constexpr int64_t kHopperWorkspace = 32 * 1024 * 1024; // 32 MiB + static constexpr int64_t kDefaultWorkspace = 4 * 1024 * 1024; // 4 MiB + + explicit GemmConfig(const se::gpu::GemmConfig& cfg) : + se::gpu::GemmConfig(cfg) { } -struct GemmConfig { static StatusOr For(const HloInstruction* gemm); static StatusOr For(mlir::lmhlo_gpu::GEMMOp op); + static StatusOr For(mlir::lmhlo_gpu::CublasLtMatmulOp op); static StatusOr For( const Shape& lhs_shape, absl::Span lhs_batch_dims, @@ -120,160 +125,47 @@ struct GemmConfig { absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, - std::optional algorithm, int64_t compute_precision, - bool grad_x, bool grad_y); - - // As above with additional `c_shape` and `bias_shape_ptr` parameter, both - // which are only necessarily for F8 gemms. 
- static StatusOr For( - const Shape& lhs_shape, absl::Span lhs_batch_dims, - absl::Span lhs_contracting_dims, const Shape& rhs_shape, - absl::Span rhs_batch_dims, - absl::Span rhs_contracting_dims, const Shape& c_shape, - const Shape* bias_shape_ptr, const Shape& output_shape, double alpha_real, - double alpha_imag, double beta, std::optional algorithm, - int64_t compute_precision, bool grad_x, bool grad_y); - - template ::value || - std::is_same::value>> - static StatusOr For(CublasLtMatmulMaybeF8Op op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().has_value()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.value()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } - } - } - - Shape bias_shape; - if (op.getBias() != nullptr) { - bias_shape = GetShape(op.getBias()); - } - return GemmConfig::For( - GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), - op.getBias() == nullptr ? nullptr : &bias_shape, GetShape(op.getD()), - op.getAlphaReal().convertToDouble(), - op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), - op.getAlgorithm(), compute_precision, /*grad_x=*/false, /*grad_y=*/false); - } - - MatrixLayout lhs_layout; - MatrixLayout rhs_layout; - MatrixLayout c_layout; - MatrixLayout output_layout; - complex128 alpha; - double beta; - std::optional algorithm; - int64_t compute_precision; - bool grad_x, grad_y; + int64_t algorithm, int64_t compute_precision, + se::gpu::BlasLt::Epilogue epilogue); + + struct DescriptorsTuple { + se::gpu::MatrixDescriptor lhs; + se::gpu::MatrixDescriptor rhs; + se::gpu::OutputMatrixDescriptor output; + bool operands_swapped; + }; + StatusOr GetMatrixDescriptors( + se::DeviceMemoryBase lhs_buf, se::DeviceMemoryBase rhs_buf, + se::DeviceMemoryBase out_buf) const; }; -StatusOr GetBlasComputationType( - PrimitiveType lhs_dtype, PrimitiveType output_dtype, - int64_t compute_precision); - -namespace cublas_lt { - -// Returns the type for the alpha and beta scalars. -se::blas::DataType GetScaleType(se::blas::DataType c_type, - se::blas::ComputationType computation_type); - -} // namespace cublas_lt - // Run the given GEMM instruction `gemm` subject to the configuration // in `gemm_config` and the passed buffers. // // If `algorithm` is provided, it overrides the one specified in `config`. 
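+// A minimal call sketch (hypothetical buffer variables; the workspace may be +// sized with GemmConfig::kDefaultWorkspace): +//   TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, workspace, +//                              /*deterministic_ops=*/false, stream));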
-Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, bool deterministic_ops, - se::Stream* stream, - std::optional algorithm = std::nullopt, - se::blas::ProfileResult* profile_result = nullptr); - -namespace cublas_lt { - -StatusOr EpilogueAddsVectorBias(GemmBackendConfig_Epilogue epilogue); -StatusOr EpilogueHasAuxiliaryOutput(GemmBackendConfig_Epilogue epilogue); - -} // namespace cublas_lt - -StatusOr AsBlasDataType(PrimitiveType dtype); +Status RunGemm( + const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, + se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase workspace_buffer, bool deterministic_ops, + se::Stream* stream, + std::optional algorithm = std::nullopt, + se::blas::ProfileResult* profile_result = nullptr); -#if GOOGLE_CUDA || TF_HIPBLASLT +namespace gpublas_lt { -namespace cublas_lt { +StatusOr EpilogueAddsVectorBias( + GemmBackendConfig_Epilogue epilogue); +StatusOr EpilogueHasAuxiliaryOutput( + GemmBackendConfig_Epilogue epilogue); +StatusOr AsBlasLtEpilogue( + GemmBackendConfig_Epilogue epilogue); StatusOr AsBlasLtEpilogue( mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue); -class MatmulPlan { - public: - static StatusOr From(const GemmConfig& config, - se::gpu::BlasLt::Epilogue epilogue); - - Status ExecuteOnStream( - se::Stream* stream, se::DeviceMemoryBase a_buffer, - se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer, - se::DeviceMemoryBase d_buffer, - se::DeviceMemoryBase bias_buffer, // may be null - se::DeviceMemoryBase aux_buffer, // may be null - se::DeviceMemoryBase a_scale_buffer, se::DeviceMemoryBase b_scale_buffer, - se::DeviceMemoryBase c_scale_buffer, se::DeviceMemoryBase d_scale_buffer, - se::DeviceMemoryBase d_amax_buffer, - const se::gpu::BlasLt::MatmulAlgorithm& algorithm, - se::ScratchAllocator& scratch_allocator, - se::blas::ProfileResult* profile_result = nullptr) const; - - StatusOr> GetAlgorithms( - se::Stream* stream) const; - - private: - MatmulPlan(se::gpu::BlasLt::MatmulPlan plan, complex128 alpha, double beta, - bool must_swap_operands) - : plan_(std::move(plan)), - alpha_(alpha), - beta_(beta), - must_swap_operands_(must_swap_operands) {} - - template - Status DoMatmul(se::Stream* stream, se::DeviceMemoryBase a_buffer, - se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer, - se::DeviceMemoryBase d_buffer, - se::DeviceMemoryBase bias_buffer, // may be null - se::DeviceMemoryBase aux_buffer, // may be null - se::DeviceMemoryBase a_scale, se::DeviceMemoryBase b_scale, - se::DeviceMemoryBase c_scale, se::DeviceMemoryBase d_scale, - se::DeviceMemoryBase d_amax, - const se::gpu::BlasLt::MatmulAlgorithm& algorithm, - se::ScratchAllocator& scratch_allocator, - se::blas::ProfileResult* profile_result) const; - - se::gpu::BlasLt::MatmulPlan plan_; - complex128 alpha_; - double beta_; - bool must_swap_operands_; -}; - -} // namespace cublas_lt - -#endif // GOOGLE_CUDA || TF_HIPBLASLT +} // namespace gpublas_lt } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_MATMUL_UTILS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_ \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.cc b/third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.cc index aabb1c7e843d6d..ebb663dbe318e2 100644 --- a/third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.cc +++ b/third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.cc @@ -75,7 
+75,7 @@ void PopulateCublasLtMatmulAttrEncoding(CustomCallAttrEncodingSet& encoding) { return cublas_lt::AsBlasLtEpilogue(value).value(); }); } //===----------------------------------------------------------------------===// // cuBLASLt matmul custom call implementation. //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/gemm.cc b/third_party/xla/xla/service/gpu/runtime/gemm.cc index 29edea4828f39a..abaa435bde6883 100644 --- a/third_party/xla/xla/service/gpu/runtime/gemm.cc +++ b/third_party/xla/xla/service/gpu/runtime/gemm.cc @@ -92,8 +92,9 @@ Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config, // we pass a non-null ProfileResult, DoGemmWithAlgorithm should // always return true, and the actual success-ness is returned in // ProfileResult::is_valid. - TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, deterministic_ops, - stream, algorithm, &profile_result)); + se::DeviceMemoryBase workspace{}; + TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, workspace, false, + stream, algorithm, &profile_result)); return std::move(profile_result); })); @@ -152,8 +153,9 @@ static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options, #endif } - return RunGemm(*gemm_config, lhs_data, rhs_data, output_data, - deterministic_ops, stream); + se::DeviceMemoryBase workspace{}; + return RunGemm(*gemm_config, lhs_data, rhs_data, output_data, workspace, false, + stream); } static absl::Status InitCuBLASImpl( diff --git a/third_party/xla/xla/service/gpu/runtime/support.h b/third_party/xla/xla/service/gpu/runtime/support.h index a73e1a44870b50..3f4f6a8ceca38b 100644 --- a/third_party/xla/xla/service/gpu/runtime/support.h +++ b/third_party/xla/xla/service/gpu/runtime/support.h @@ -112,7 +112,7 @@ inline StatusOr GetGemmConfig( return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs), rhs_batch, rhs_contract, c_shape, bias_shape_ptr, ToShape(out), alpha_real, alpha_imag, beta, algorithm, - compute_precision, grad_x, grad_y); + compute_precision, se::gpu::BlasLt::Epilogue::kDefault); } // adds Dot Dimension Attribute encodings for calls to Gemm and cuBLASLt diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index 7904e6d5e86f16..c4d858bcd6dc16 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -98,7 +98,107 @@ StatusOr DataLayoutToXlaLayout( return LayoutUtil::MakeLayoutFromMajorToMinor(layout); } -} // anonymous namespace +std::vector KeepNonFailures( + absl::Span profile_results) { + // Filter out all failures except WRONG_RESULT, because false-positives are + // possible (e.g. perhaps the reference algorithm is the one that's + // incorrect!). Other failures can be detected with high accuracy. E.g. + // REDZONE_MODIFIED which is also quite severe. + std::vector filtered_results; + absl::c_copy_if(profile_results, std::back_inserter(filtered_results), + [](const AutotuneResult& r) { + return !r.has_failure() || + r.failure().kind() == AutotuneResult::WRONG_RESULT; + }); + return filtered_results; +} + +Status AllAlgorithmsFailedInternalError( + absl::optional instr_str, + absl::Span profile_results) { + std::ostringstream msg; + if (instr_str.has_value()) { + msg << "All algorithms tried for " << instr_str.value() + << " failed. Falling back to default algorithm. 
Per-algorithm " + "errors:"; + } else { + msg << "All algorithms failed. Falling back to the default algorithm. " + << "Per-algorithm errors:"; + } + for (const auto& result : profile_results) { + msg << "\n " << result.failure().msg(); + } + return Internal("%s", msg.str()); +} + +Status NoAlgorithmSuppliedInternalError( + absl::optional instr_str) { + std::ostringstream msg; + if (instr_str.has_value()) { + msg << "There are no algorithm candidates for computing: \n " + << instr_str.value() + << "\nThis likely means that the instruction shape is not supported by " + "the target GPU library."; + } else { + msg << "There are no algorithm candidates for computing the instruction.\n" + "This likely means that the instruction shape is not supported by " + "the target GPU library."; + } + return Internal("%s", msg.str()); +} + +void SortAutotuningResultsByRunTime(std::vector& results) { + absl::c_sort(results, + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return tsl::proto_utils::FromDurationProto(lhs.run_time()) < + tsl::proto_utils::FromDurationProto(rhs.run_time()); + }); +} + +absl::Span TopResultsWithinMeasurementError( + std::vector& results_sorted_by_runtime) { + // This value was picked by repeatedly running a few kernels that run for a + // short time and observing the run-time variance. A more rigorous analysis + // of the measurement error might yield a better error threshold. + constexpr absl::Duration kMeasurementError = absl::Microseconds(4); + + absl::Duration min_time = tsl::proto_utils::FromDurationProto( + results_sorted_by_runtime.front().run_time()); + absl::Duration limit_time = min_time + kMeasurementError; + + auto limit_time_it = absl::c_find_if( + results_sorted_by_runtime, [limit_time](const AutotuneResult& x) { + return tsl::proto_utils::FromDurationProto(x.run_time()) > limit_time; + }); + return absl::MakeSpan(&*results_sorted_by_runtime.begin(), &*limit_time_it); +} +} // anonymous namespace + +StatusOr PickBestResult( + absl::Span profile_results, + absl::optional instr_str) { + if (profile_results.empty()) { + return NoAlgorithmSuppliedInternalError(instr_str); + } + + std::vector filtered_results = + KeepNonFailures(profile_results); + + if (filtered_results.empty()) { + return AllAlgorithmsFailedInternalError(instr_str, profile_results); + } + + // Kernel run-time measurements within kMeasurementError are not precise. + // Consider the lowest measurements within the error margin as equivalent and + // within them prefer algorithms that use the least amount of scratch memory. + SortAutotuningResultsByRunTime(filtered_results); + auto top_within_error = TopResultsWithinMeasurementError(filtered_results); + return *absl::c_min_element(top_within_error, [](const AutotuneResult& lhs, + const AutotuneResult& rhs) { + return lhs.scratch_bytes() < rhs.scratch_bytes(); + }); +} + StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, @@ -365,7 +465,6 @@ Status ExecuteKernelOnStream(const se::KernelBase& kernel, se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel, *kernel_args); } - // Unimplemented for integers yet. 
template typename std::enable_if::value, @@ -547,50 +646,5 @@ bool RequireDeterminism(const HloModuleConfig& config) { config.debug_options().xla_gpu_deterministic_ops(); } -StatusOr PickBestResult( - absl::Span profile_results, - std::optional instr_str, - HloModuleConfig hlo_module_config) { - std::vector filtered_results; - - // For now, we ignore WRONG_RESULT failures because false-positives are - // possible (e.g. perhaps the reference algorithm is the one that's - // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're - // quite severe and can be detected with high accuracy. - absl::c_copy_if( - profile_results, std::back_inserter(filtered_results), - [](const AutotuneResult& r) { - return !(r.has_failure() && - r.failure().kind() != AutotuneResult::WRONG_RESULT); - }); - - if (filtered_results.empty()) { - std::ostringstream msg; - if (instr_str.has_value()) { - msg << "All algorithms tried for " << instr_str.value() - << " failed. Falling back to default algorithm. Per-algorithm " - "errors:"; - } else { - msg << "All algorithms failed. Falling back to the default algorithm. " - << "Per-algorithm errors:"; - } - for (const auto& result : profile_results) { - msg << "\n " << result.failure().msg(); - } - return InternalError("%s", msg.str()); - } - - auto selected_result = filtered_results.begin(); - if (!RequireDeterminism(hlo_module_config)) { - selected_result = absl::c_min_element( - filtered_results, - [](const AutotuneResult& lhs, const AutotuneResult& rhs) { - return tsl::proto_utils::FromDurationProto(lhs.run_time()) < - tsl::proto_utils::FromDurationProto(rhs.run_time()); - }); - } - return *selected_result; -} - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.h b/third_party/xla/xla/service/gpu/stream_executor_util.h index 4e291b5e9a04cc..714ae17359c222 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.h +++ b/third_party/xla/xla/service/gpu/stream_executor_util.h @@ -111,8 +111,7 @@ StatusOr GetDNNDataTypeFromPrimitiveType(PrimitiveType type); // If deterministic output is requested, returns first (not failing) result. StatusOr PickBestResult( absl::Span profile_results, - std::optional instr_str, - HloModuleConfig hlo_module_config); + std::optional instr_str); // Returns whether determinism is required. bool RequireDeterminism(const HloModuleConfig& config); diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 4691c4a7270c12..b0d1f5df2cb587 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -166,6 +166,22 @@ xla_cc_test( ]), ) +xla_cc_test( + name = "gpu_hlo_runner_test", + srcs = ["gpu_hlo_runner_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//xla/service:gpu_plugin", + "//xla/service:hlo_parser", + "//xla/tests:filecheck", + "//xla:error_spec", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + ], +) + + xla_cc_test( name = "gemm_broadcast_folding_rewrite_test", srcs = [ diff --git a/third_party/xla/xla/service/gpu/tests/gpu_hlo_runner_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_hlo_runner_test.cc new file mode 100644 index 00000000000000..a92b9b35c10b42 --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/gpu_hlo_runner_test.cc @@ -0,0 +1,129 @@ +/* Copyright 2022 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include "xla/error_spec.h" +#include "xla/literal_comparison.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_utils.h" + +namespace xla { +namespace gpu { + +template +std::vector MakePointerVector(std::vector& input_vec) { + std::vector output_pointers; + output_pointers.reserve(input_vec.size()); + for (auto& input : input_vec) { + output_pointers.push_back(&input); + } + return output_pointers; +} + + +class HloRunnerTest : public GpuCodegenTest {}; + +TEST_F(HloRunnerTest, RunSingle) { + + std::ifstream ifs("input.hlo"); + ASSERT_TRUE(ifs.good()); + + std::stringstream buffer; + buffer << ifs.rdbuf(); + + HloModuleConfig config = GetModuleConfigForTest(); +#if 1 + //config.set_num_partitions(8); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(buffer.str(), + config)); + + auto ref_module = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(auto exec, test_runner_.CreateExecutable(std::move(module), true)); + + VLOG(0) << "Creating fake args.."; + TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, xla::MakeFakeArguments(ref_module.get(), + true, /*pseudo-random*/ + false /* use large range*/)); + auto arg_ptrs = MakePointerVector(fake_arguments); + + auto& ref_runner = HloTestBase::reference_runner_; + TF_ASSERT_OK_AND_ASSIGN( + auto ref_exec, ref_runner.CreateExecutable(std::move(ref_module), true)); + + // TF_ASSERT_OK_AND_ASSIGN(auto truth, + // ReadLiteralFromProto("/tf/xla/expected.pb")); + // TF_ASSERT_OK_AND_ASSIGN(auto truth, + // ref_runner.ExecuteWithExecutable(ref_exec.get(), arg_ptrs, nullptr)); + // WriteLiteralToTempFile(truth, "expected"); + //VLOG(0) << "Got expected literal from file.. 
running test"; + + TF_ASSERT_OK_AND_ASSIGN( + auto test_res, test_runner_.ExecuteWithExecutable(exec.get(), arg_ptrs)); + + VLOG(0) << "Running reference exec.."; + TF_ASSERT_OK_AND_ASSIGN( + auto truth, ref_runner.ExecuteWithExecutable(ref_exec.get(), arg_ptrs)); + + ErrorSpec error_spec{1e-2, 1e-3}; + //ErrorSpec error_spec(1e-5 /*abs*/, 1e-5 /*rel*/); + ASSERT_EQ(literal_comparison::Near(/*expected=*/truth, + /*actual=*/test_res, + /*error=*/error_spec, + /*detailed_message=*/true, {}), OkStatus()); + + // EXPECT_TRUE(RunAndCompare(std::move(module), + // // absl::Span< xla::Literal * const>(arg_ptrs.data(), arg_ptrs.size()), error_spec)); +#else + int NumReplicas = 8, NumParts = 1; + config.set_replica_count(NumReplicas); + config.set_num_partitions(NumParts); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(buffer.str(), config)); + DeviceAssignment assn(/*replica_count=*/NumReplicas, + /*computation_count=*/NumParts); + for (int64_t i = 0, k = 0; i < NumReplicas; i++) + for (int64_t j = 0; j < NumParts; j++) { + assn(i, j) = k++; + } + + auto fake_arguments = xla::MakeFakeArguments( + module.get(), + true, /*pseudo-random*/ + false /* use large range*/).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(auto exec, + test_runner_.CreateExecutable(std::move(module), true)); + + for(int i = 0; i < 10; i++) { + VLOG(0) << "Running iteration #" << i; + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + HloTestBase::ExecuteReplicated( + [&](int64_t){ return exec.get(); }, + [&fake_arguments](int64_t replica_id) + { return fake_arguments.size(); }, + [&fake_arguments](int64_t replica_id, int64_t idx) + { return &fake_arguments[idx]; }, + NumReplicas, false /*run hlo*/, &assn)); + ASSERT_EQ(results.size(), NumReplicas); + } +#endif +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/stream_executor/blas.cc b/third_party/xla/xla/stream_executor/blas.cc index efcda8daa46970..8aba9e0c509e37 100644 --- a/third_party/xla/xla/stream_executor/blas.cc +++ b/third_party/xla/xla/stream_executor/blas.cc @@ -22,6 +22,27 @@ limitations under the License. namespace stream_executor { namespace blas { +// TODO(ezhulenev): We need a scoped thread local map-like container to make +// sure that we can have multiple BlasSupport instances that do not overwrite +// each others workspaces. For not it's ok as we know that this can't happen. 
+static thread_local DeviceMemoryBase* workspace_thread_local = nullptr; + +BlasSupport::ScopedWorkspace::ScopedWorkspace(BlasSupport* blas, + DeviceMemoryBase* workspace) + : blas_(blas) { + blas->SetWorkspace(workspace); +} + +BlasSupport::ScopedWorkspace::~ScopedWorkspace() { blas_->ResetWorkspace(); } + +DeviceMemoryBase* BlasSupport::GetWorkspace() { return workspace_thread_local; } + +void BlasSupport::SetWorkspace(DeviceMemoryBase* workspace) { + workspace_thread_local = workspace; +} + +void BlasSupport::ResetWorkspace() { workspace_thread_local = nullptr; } + std::string TransposeString(Transpose t) { switch (t) { case Transpose::kNoTranspose: diff --git a/third_party/xla/xla/stream_executor/blas.h b/third_party/xla/xla/stream_executor/blas.h index 3ccbfc6b7506fb..49e71e90feadd0 100644 --- a/third_party/xla/xla/stream_executor/blas.h +++ b/third_party/xla/xla/stream_executor/blas.h @@ -59,6 +59,10 @@ struct half; namespace stream_executor { +namespace gpu { +struct BlasLt; +} // namespace gpu + class Stream; class ScratchAllocator; @@ -211,6 +215,7 @@ constexpr ComputePrecision kDefaultComputePrecision = 0; class BlasSupport { public: virtual ~BlasSupport() {} + virtual gpu::BlasLt *GetBlasLt() = 0; // Performs a BLAS y <- ax+y operation. virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, @@ -505,10 +510,41 @@ class BlasSupport { virtual tsl::Status GetVersion(std::string *version) = 0; + // TODO(ezhulenev): We should never pass ScratchAllocator to any of the APIs + // in this file, because it makes them incompatible with command buffers (CUDA + // graphs). We should pass workspace memory explicitly to all APIs. However + // this is a giant change, so currently we work around it by setting a thread + // local workspace and rely on `ScopedBlasWorkspace` RAII helper to reset it. + // + // APIs that get ScratchAllocator ignore this workspace, and continue + // allocating scratch memory on demand. + class ScopedWorkspace { + public: + ScopedWorkspace(BlasSupport *blas, DeviceMemoryBase *workspace); + ~ScopedWorkspace(); + + private: + BlasSupport *blas_; + }; + protected: + DeviceMemoryBase *GetWorkspace(); BlasSupport() {} private: + // Workspace memory pointer is thread local, once it is set all Blas + // operations issued from a caller thread might use it if it has large enough + // size. It's a user responsibility to make sure that workspace will outlive + // all issued BLAS operations. + // + // TODO(ezhulenev): This is a giant footgun! We have to remove it and use + // explicit workspace memory argument for all BLAS operations. + void SetWorkspace(DeviceMemoryBase *workspace); + + // Resets user-defined workspace memory, so that Blas operations can use their + // own memory pool for allocating workspace. + void ResetWorkspace(); + BlasSupport(const BlasSupport &) = delete; void operator=(const BlasSupport &) = delete; }; diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index acc52feca11755..b4fff2f2b0f235 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -1,13 +1,16 @@ # Description: # GPU-platform specific StreamExecutor support code. 
- +load( + "//tensorflow:tensorflow.bzl", + "tf_gpu_kernel_library", +) load( "//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured", ) load( "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_is_configured", + "if_rocm_is_configured", "rocm_copts" ) load( "@local_tsl//tsl:tsl.bzl", @@ -56,6 +59,39 @@ cc_library( ]), ) +cc_library( + name = "gpu_blas_lt", + srcs = if_gpu_is_configured(["gpu_blas_lt.cc"]), + hdrs = if_gpu_is_configured(["gpu_blas_lt.h"]), + deps = if_gpu_is_configured([ + "//xla/stream_executor", + "//xla:xla_data_proto_cc", + "//xla:statusor", + "//xla:util", + "//xla:shape_util", + #"//tensorflow/core/platform:env", + "@local_tsl//tsl/util:env_var", + "@com_google_absl//absl/types:any", + ]), +) + + +cc_library( + name = "gpu_blas_lt_gemm_runner", + srcs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.cc"]), + hdrs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.h"]), + deps = if_gpu_is_configured([ + "//tensorflow/core/protobuf:autotuning_proto_cc", + "//xla:autotune_results_proto_cc", + "//xla:xla_proto_cc", + "//xla:xla_data_proto_cc", + "//xla/stream_executor:scratch_allocator", + "//xla/service/gpu:autotuner_util", + "//xla:debug_options_flags", + ":gpu_blas_lt", + ]), +) + cc_library( name = "gpu_diagnostics_header", hdrs = if_gpu_is_configured(["gpu_diagnostics.h"]), diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD.rej b/third_party/xla/xla/stream_executor/gpu/BUILD.rej new file mode 100644 index 00000000000000..c80fd9b5329024 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/BUILD.rej @@ -0,0 +1,92 @@ +diff a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD (rejected hunks) +@@ -174,7 +210,7 @@ tsl_gpu_library( + "//tensorflow/compiler/tf2xla:__subpackages__", + "//xla:__subpackages__", + "//tensorflow/core/common_runtime/gpu:__subpackages__", +- "//tensorflow/stream_executor:__subpackages__", +// "//xla/stream_executor:__subpackages__", + ]), + deps = [ + "//xla/stream_executor:multi_platform_manager", +@@ -362,33 +398,65 @@ cc_library( + ]) + ["@local_tsl//tsl/platform:statusor"], + ) + +//# cc_library( +//# name = "redzone_allocator", +//# srcs = if_gpu_is_configured(["redzone_allocator.cc"]), +//# hdrs = if_gpu_is_configured(["redzone_allocator.h"]), +//# copts = tsl_copts(), +//# visibility = set_external_visibility([ +//# "//xla/service/gpu:__subpackages__", +//# "//xla/stream_executor:__subpackages__", +//# "//tensorflow/core/kernels:__subpackages__", +//# ]), +//# deps = if_gpu_is_configured([ +//# ":asm_compiler", +//# ":gpu_asm_opts", +//# "@com_google_absl//absl/base", +//# "@com_google_absl//absl/container:fixed_array", +//# "@com_google_absl//absl/status", +//# "@com_google_absl//absl/strings:str_format", +//# "@com_google_absl//absl/types:optional", +//# "@local_tsl//tsl/lib/math:math_util", +//# "@local_tsl//tsl/platform:errors", +//# "@local_tsl//tsl/platform:logging", +//# "@local_tsl//tsl/framework:allocator", +//# "//xla/stream_executor:device_memory", +//# "//xla/stream_executor:device_memory_allocator", +//# "//xla/stream_executor:scratch_allocator", +//# "//xla/stream_executor:stream_executor_headers", +//# "@local_tsl//tsl/platform:status", +//# ]), +//# ) +// +//tf_gpu_kernel_library( +// name = "redzone_allocator_kernel", +// hdrs = if_gpu_is_configured(["redzone_allocator_kernel.h"]), +// srcs = if_cuda_is_configured(["redzone_allocator_kernel_cuda.cc"]) + +// if_rocm_is_configured(["redzone_allocator_kernel_rocm.cu.cc"]), +// deps = [":gpu_asm_opts"], +//) +// 
+ cc_library( + name = "redzone_allocator", +- srcs = if_gpu_is_configured(["redzone_allocator.cc"]), + hdrs = if_gpu_is_configured(["redzone_allocator.h"]), +- copts = tsl_copts(), +- visibility = set_external_visibility([ +- "//xla/service/gpu:__subpackages__", +- "//xla/stream_executor:__subpackages__", +- "//tensorflow/core/kernels:__subpackages__", +// srcs = if_gpu_is_configured(["redzone_allocator.cc"]), +// copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured([ +// "-DTENSORFLOW_USE_ROCM=1", + ]), +- deps = if_gpu_is_configured([ +- ":asm_compiler", +- ":gpu_asm_opts", +- "@com_google_absl//absl/base", +// visibility = ["//visibility:public"], +// deps = [ + "@com_google_absl//absl/container:fixed_array", +- "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", +- "@local_tsl//tsl/lib/math:math_util", +- "@local_tsl//tsl/platform:errors", +- "@local_tsl//tsl/platform:logging", +- "@local_tsl//tsl/framework:allocator", +// "@com_google_absl//absl/strings", +// ":redzone_allocator_kernel", +// ":gpu_asm_opts", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream_executor_headers", +- "@local_tsl//tsl/platform:status", +// ] + if_cuda_is_configured([ +// "//tensorflow/stream_executor/cuda:ptxas_utils", + ]), + ) + diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc new file mode 100644 index 00000000000000..49f0f96bae1aa4 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -0,0 +1,285 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/stream_executor/gpu/gpu_blas_lt.h" + +#include +#include +#include +#include + +#include "xla/primitive_util.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/util.h" +#include "tensorflow/tsl/util/env_var.h" + +namespace stream_executor { + +namespace gpu { + +using blas::ComputationType; +using blas::DataType; +using xla::PrimitiveType; + +bool GpuBlasLtEnabled() { + static std::atomic_bool result{[] { + bool value = false; + tsl::ReadBoolFromEnvVar("TF_ENABLE_GPU_BLASLT", + /*default_value=*/false, &value); + return value; + }()}; + return result; +} + +namespace { + +bool TF32_Enabled() { + static std::atomic_bool result{[] { + bool value = false; + (void)tsl::ReadBoolFromEnvVar("ROCM_XF32", + /*default_value=*/false, &value); + return value; + }()}; + return result; +} + +bool Fast_16F_Enabled() { + static std::atomic_bool result{[] { + bool value = false; + (void)tsl::ReadBoolFromEnvVar("ROCM_FAST_16F", + /*default_value=*/false, &value); + return value; + }()}; + return result; +} + +} // namespace + +xla::StatusOr AsBlasDataType(PrimitiveType dtype) { + switch (dtype) { + case PrimitiveType::S8: + return DataType::kInt8; + case PrimitiveType::F16: + return DataType::kHalf; + case PrimitiveType::BF16: + return DataType::kBF16; + case PrimitiveType::F32: + return DataType::kFloat; + case PrimitiveType::S32: + return DataType::kInt32; + case PrimitiveType::F64: + return DataType::kDouble; + case PrimitiveType::C64: + return DataType::kComplexFloat; + case PrimitiveType::C128: + return DataType::kComplexDouble; + default: + return xla::InternalError( + "AsBlasDataType: unsupported type: %s", + xla::primitive_util::LowercasePrimitiveTypeName(dtype)); + } +} + +xla::StatusOr GetBlasComputationType( + DataType lhs_dtype, DataType output_dtype, int64_t /*compute_precision*/) { + + auto f16_comp = Fast_16F_Enabled() ? + ComputationType::kF16AsF32 : ComputationType::kF32, + bf16_comp = Fast_16F_Enabled() ? + ComputationType::kBF16AsF32 : ComputationType::kF32; + + switch (output_dtype) { + case DataType::kHalf: // fall-through + return f16_comp; + case DataType::kBF16: + return bf16_comp; + case DataType::kFloat: // fall-through + if (lhs_dtype == DataType::kHalf) return f16_comp; + if (lhs_dtype == DataType::kBF16) return bf16_comp; + return TF32_Enabled() ? ComputationType::kTF32AsF32 + : ComputationType::kF32; + case DataType::kComplexFloat: + return ComputationType::kF32; + case DataType::kDouble: // fall-through + case DataType::kComplexDouble: + return ComputationType::kF64; + case DataType::kInt32: + return ComputationType::kI32; + default: + return xla::InternalError("GetBlasComputationType: unsupported type"); + } +} + +MatrixLayout::MatrixLayout(blas::DataType dtype_, int64_t num_rows_, + int64_t num_cols_, MatrixLayout::Order order_, + int64_t batch_size_, + absl::optional leading_dim_stride_, + absl::optional batch_stride_, + absl::optional transpose_) + : dtype(dtype_), + num_rows(num_rows_), + num_cols(num_cols_), + order(order_), + batch_size(batch_size_) { + if (!leading_dim_stride_) { + leading_dim_stride = order == Order::kRowMajor ? num_cols : num_rows; + } else { + leading_dim_stride = *leading_dim_stride_; + } + if (!batch_stride_) { + batch_stride = (batch_size > 1) ? num_rows * num_cols : 0; + } else { + batch_stride = *batch_stride_; + } + transpose = transpose_ ? 
*transpose_ : blas::Transpose::kNoTranspose; +} + +void MatrixLayout::Transpose() { + std::swap(num_rows, num_cols); + order = (order == Order::kRowMajor) ? Order::kColumnMajor : Order::kRowMajor; +} + +// BLAS GeMM's output is column-major. If we require row-major, use identity: +// C^T = (A @ B)^T = B^T @ A^T. +bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, + MatrixLayout& output, MatrixLayout* c) { + bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor; + if (swap_operands) { + std::swap(lhs, rhs); + rhs.Transpose(); + // prevent layouts from being swapped two times if they are equal + if (&lhs != &rhs) { + lhs.Transpose(); + } + if (c != nullptr && c != &output) { + c->Transpose(); + } + output.Transpose(); + } + return swap_operands; +} + +/*static*/ auto BlasLt::GetMatmulPlan(const Stream* stream, + const GemmConfig& cfg) + -> xla::StatusOr { + auto blas = Get(stream); + if (blas == nullptr) { + return xla::InternalError("BlasLt is unavailable"); + } + return blas->GetMatmulPlan(cfg); +} + +/* static */ auto BlasLt::CreateGroupedMatmulPlan(Stream* stream, + const GroupedGemmConfig& cfg) -> xla::StatusOr { + auto blas = Get(stream); + if (blas == nullptr) { + return xla::InternalError("BlasLt is unavailable"); + } + return blas->GetGroupedMatmulPlan(stream, cfg); +} + +/*static*/ BlasLt* BlasLt::Get(const Stream* stream) { + auto blas = stream->parent()->AsBlas(); + return (blas != nullptr ? blas->GetBlasLt() : nullptr); +} + +DataType GetScaleType(DataType c_type, ComputationType compute_type) { + if (compute_type == ComputationType::kF32 && + c_type != DataType::kComplexFloat) { + return DataType::kFloat; + } + if (compute_type == ComputationType::kF16) return DataType::kFloat; + return c_type; +} + + +namespace { + +const std::vector TransposeNames = { + "N", // kNoTranspose + "T", // kTranspose + "C", // kConjugateTranspose +}; + +xla::StatusOr Transpose2String(blas::Transpose type) { + size_t idx = static_cast< size_t >(type); + if (idx < TransposeNames.size()) return TransposeNames[idx]; + return xla::InternalError("Unknown transpose type!"); +} + +xla::StatusOr String2Transpose(absl::string_view s) { + for(size_t i = 0; i < TransposeNames.size(); i++) { + if (s == TransposeNames[i]) return static_cast< blas::Transpose >(i); + } + return xla::InternalError("Unknown tranpose type!"); +} + +const std::vector TypeNames = { + "f32_r", //kFloat = 0, + "f64_r", //kDouble = 1, + "f16_r", //kHalf = 2, + "i8_r", //kInt8 = 3, + "i32_r", //kInt32 = 4, + "f32_c", //kComplexFloat = 5, + "f64_c", //kComplexDouble = 6, + "bf16_r", //kBF16 = 7, +}; + +xla::StatusOr Type2String(blas::DataType type) { + size_t idx = static_cast< size_t >(type); + if (idx < TypeNames.size()) return TypeNames[idx]; + return xla::InternalError("Unknown data type!"); +} + +} // namespace + +std::string ToCSVString(const GemmConfig& cfg, bool full_string) { + + ///constexpr char kCsvComment = '#'; + constexpr char kCsvSep = ','; + + const auto& L = cfg.lhs_layout, &R = cfg.rhs_layout, &O = cfg.output_layout; + + std::ostringstream oss; + auto type_a = Type2String(L.dtype).value(), + type_b = Type2String(R.dtype).value(), + type_c = Type2String(O.dtype).value(), + trans_a = Transpose2String(L.transpose).value(), + trans_b = Transpose2String(R.transpose).value(); + +// LHS: k x n +// RHS: m x k +// OUT: m x n + // VLOG(0) << "LHS: " << L.num_cols << "x" << L.num_rows; + // VLOG(0) << "RHS: " << R.num_cols << "x" << R.num_rows; + // VLOG(0) << "OUT: " << O.num_cols << "x" << 
O.num_rows; + int n = L.num_rows, k = L.num_cols, m = O.num_cols; + oss << m << kCsvSep << n << kCsvSep << k << kCsvSep + << O.batch_size << kCsvSep << trans_a << kCsvSep + << trans_b << kCsvSep << type_a << kCsvSep + << type_b << kCsvSep << type_c << kCsvSep << L.leading_dim_stride + << kCsvSep << R.leading_dim_stride << kCsvSep + << O.leading_dim_stride << kCsvSep << L.batch_stride << kCsvSep + << R.batch_stride << kCsvSep << O.batch_stride; + + if (full_string) { + // NOTE: epilogue is required for MatmulPlan caching ! + oss << kCsvSep << cfg.alpha.real() << kCsvSep << cfg.alpha.imag() << kCsvSep << cfg.beta << kCsvSep << (int64_t)cfg.epilogue; + } + + return oss.str(); +} + +} // namespace gpu + +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h new file mode 100644 index 00000000000000..722ddc3717fca3 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h @@ -0,0 +1,278 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_H_ +#define XLA_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "absl/types/any.h" +#include "xla/statusor.h" +#include "xla/types.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/host_or_device_scalar.h" +#include "xla/xla_data.pb.h" + +namespace stream_executor { + +namespace gpu { + +bool GpuBlasLtEnabled(); + +xla::StatusOr AsBlasDataType(xla::PrimitiveType dtype); + +xla::StatusOr GetBlasComputationType( + blas::DataType lhs_dtype, blas::DataType output_dtype, + int64_t compute_precision); + +// Returns the type for the alpha and beta scalars. +blas::DataType GetScaleType(blas::DataType c_type, + blas::ComputationType computation_type); + +struct MatrixLayout { // plain MatrixLayout which is extended with create + // functions in matmul_utils.h + enum class Order { + kRowMajor, // Elements in the same row are contiguous in memory. + kColumnMajor, // Elements in the same column are contiguous in memory. + }; + + MatrixLayout() = default; + + MatrixLayout(blas::DataType dtype_, int64_t num_rows_, int64_t num_cols_, + Order order_, int64_t batch_size_ = 1, + absl::optional leading_dim_stride_ = {}, + absl::optional batch_stride_ = {}, + absl::optional transpose_ = {}); + + void Transpose(); + + blas::DataType dtype; + // `num_rows` / `num_cols` are for the "logical" matrix shape: + // i.e. the contracting dim has size `num_cols` for LHS operands and + // `num_rows` for RHS operands. + int64_t num_rows; + int64_t num_cols; + Order order; + int64_t batch_size; + int64_t leading_dim_stride; + // `batch_stride` is set to `0` when `batch_size == 1`. 
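+  // For example (hypothetical values): a column-major 128x64 layout with +  // batch_size 4 gets the constructor defaults leading_dim_stride = 128 and +  // batch_stride = 128 * 64 = 8192.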
+ int64_t batch_stride; + blas::Transpose transpose; +}; + +// Compact version of the matrix layout used to pass matrices +// to the underlying blas API +struct MatrixDescriptor { + DeviceMemoryBase data; + int64_t leading_dim_stride = 0; + int64_t batch_stride = 0; + blas::DataType type{}; + blas::Transpose transpose{}; + + template + DeviceMemory cast() const { + return DeviceMemory(data); + } +}; + +struct OutputMatrixDescriptor : public MatrixDescriptor { + OutputMatrixDescriptor(MatrixDescriptor&& parent) noexcept + : MatrixDescriptor(std::move(parent)) {} + int64_t batch_size = 0; + int64_t m = 0, n = 0, k = 0; + blas::ComputationType compute_type{}; +}; + +// BLAS GeMM's output is column-major. If we require row-major, use identity: +// C^T = (A @ B)^T = B^T @ A^T. +bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, + MatrixLayout& output, MatrixLayout* c = nullptr); + +struct GemmConfig; + +struct GroupedGemmConfig { + int64_t m, n, k, batch_count; + blas::Transpose trans_a, trans_b; + const void *alpha, *beta; + blas::DataType type_a, type_b, type_c, type_d; + int64_t lda, ldb, ldc, ldd; + blas::ComputationType compute_type; + const void **a, **b, **c; + void **d; +}; + +struct BlasLt { + + static constexpr int64_t kMaxAlgorithms = 128; + + enum class Epilogue { + kDefault = 1, // No special postprocessing + kReLU = 2, // Apply point-wise ReLU function + kBias = 4, // Add broadcasted bias vector + kBiasThenReLU = kBias | kReLU, // Apply bias and then ReLU transform + kGELU = 32, // Apply GELU point-wise transform to the results + kGELUWithAux = 32 | 1024, // Apply GELU with auxiliary output. + kBiasThenGELU = kBias | kGELU, // Apply bias and then approximate GELU. + kBiasThenGELUWithAux = kBiasThenGELU | 1024, + }; + + // Describes the location of pointers for the scaling factors alpha and beta. + enum class PointerMode { + kHost, + kDevice, + }; + + struct MatmulAlgorithm { + absl::any opaque_algo; + size_t workspace_size; + blas::AlgorithmType id; + }; + + struct MatmulPlan { + // DoMatmul provides two sets of APIs for maintaining compatibility with XLA + // and TF. One set uses scratch_allocator to allocate workspace, and the + // other lets the user provide a pre-allocated buffer as workspace. + + // Returns a list of supported algorithms for DoMatmul. The algorithms are + // returned in the order of increasing estimated compute time according to + // an internal heuristic. + virtual xla::StatusOr> GetAlgorithms( + size_t max_algorithm_count = kMaxAlgorithms, + size_t max_workspace_size = 1ll << 32) const = 0; + + // The algorithm needs to be set before calling ExecuteOnStream. + virtual xla::Status SetAlgorithm(const MatmulAlgorithm& algorithm) = 0; + + // The most general form: to be implemented by derived classes.
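+    // (A typical caller first picks one of the algorithms returned by +    // GetAlgorithms(), passes it to SetAlgorithm(), and then invokes this +    // method with the operand buffers and an optional workspace.)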
+ virtual xla::Status ExecuteOnStream( + Stream* stream, DeviceMemoryBase a_buffer, DeviceMemoryBase b_buffer, + DeviceMemoryBase c_buffer, DeviceMemoryBase d_buffer, + DeviceMemoryBase bias_buffer, // may be null + DeviceMemoryBase aux_buffer, // may be null + DeviceMemoryBase a_scale_buffer, DeviceMemoryBase b_scale_buffer, + DeviceMemoryBase c_scale_buffer, DeviceMemoryBase d_scale_buffer, + DeviceMemoryBase d_amax_buffer, + absl::optional workspace, + absl::optional scratch_allocator = absl::nullopt, + blas::ProfileResult* profile_result = nullptr) const = 0; + + virtual ~MatmulPlan() {} + + protected: + // might be used internally by ExecuteOnStream in derived classes + template + xla::Status DoMatmul(Stream* stream, xla::complex128 alpha, + DeviceMemoryBase a, DeviceMemoryBase b, double beta, + DeviceMemoryBase c, DeviceMemoryBase d, + DeviceMemoryBase bias, DeviceMemoryBase aux, + DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, + DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, + DeviceMemoryBase d_amax, + absl::optional workspace, + absl::optional scratch_allocator, + blas::ProfileResult* profile_result = nullptr) const { + Scale salpha; + if constexpr(std::is_same::value || + std::is_same::value) { + salpha = static_cast(alpha); + } else { + salpha = static_cast(alpha.real()); + } + Scale sbeta = static_cast(beta); + return DoMatmul(stream, &salpha, a, b, &sbeta, c, d, + bias, aux, a_scale, b_scale, c_scale, d_scale, + d_amax, workspace, scratch_allocator, profile_result); + } + + // The most general version to be implemented by derived classes + virtual xla::Status DoMatmul( + Stream* stream, const void* alpha, DeviceMemoryBase a, + DeviceMemoryBase b, const void* beta, DeviceMemoryBase c, + DeviceMemoryBase d, DeviceMemoryBase bias, + DeviceMemoryBase aux, DeviceMemoryBase a_scale, + DeviceMemoryBase b_scale, DeviceMemoryBase c_scale, + DeviceMemoryBase d_scale, DeviceMemoryBase d_amax, + absl::optional workspace, + absl::optional scratch_allocator, + blas::ProfileResult* profile_result = nullptr) const = 0; + }; // class MatmulPlan + + struct GroupedMatmulPlan { + + virtual xla::StatusOr> GetAlgorithms( + size_t max_algorithm_count = kMaxAlgorithms, + size_t max_workspace_size = 1ll << 32) = 0; + + virtual xla::Status SetAlgorithm(const MatmulAlgorithm& algorithm, + ScratchAllocator * scratch_allocator) = 0; + + virtual xla::Status ExecuteOnStream(Stream *stream, + const gpu::GroupedGemmConfig& cfg, + blas::ProfileResult* profile_result = nullptr) = 0; + + virtual ~GroupedMatmulPlan() {} + }; + + using MatmulPlanPtr = std::unique_ptr; + using GroupedMatmulPlanPtr = std::unique_ptr; + + virtual xla::Status Init() = 0; + + virtual xla::StatusOr GetMatmulPlan( + const GemmConfig& cfg) const = 0; + + virtual xla::StatusOr GetGroupedMatmulPlan( + Stream *stream, + const GroupedGemmConfig& config) const = 0; + + static BlasLt* Get(const Stream* stream); + + // convenience function to create MatmulPlan directly using stream + static xla::StatusOr GetMatmulPlan(const Stream* stream, + const GemmConfig& cfg); + + // convenience function to create GroupedMatmulPlan directly using stream + static xla::StatusOr CreateGroupedMatmulPlan( + Stream* stream, const GroupedGemmConfig& cfg); + + virtual ~BlasLt() {} +}; // class BlasLt + +struct GemmConfig { // plain GemmConfig which is extended with create functions + // in matmul_utils.h + MatrixLayout lhs_layout; + MatrixLayout rhs_layout; + MatrixLayout c_layout; + MatrixLayout output_layout; + xla::complex128 alpha; + double beta; + 
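+  // Roughly speaking, alpha and beta enter the usual GEMM formula
+  //   D = alpha * op(A) @ op(B) + beta * C,
+  // with the epilogue (bias, ReLU, GELU, ...) applied to D afterwards.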
blas::AlgorithmType algorithm; + int64_t compute_precision; + BlasLt::Epilogue epilogue; +}; + +std::string ToCSVString(const GemmConfig& cfg, bool full_string); + + +} // namespace gpu + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc new file mode 100644 index 00000000000000..f2a6e59eb32784 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc @@ -0,0 +1,341 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/core/util/env_var.h" +#include "xla/util.h" +#include "xla/service/gpu/autotuner_util.h" +#include "xla/shape_util.h" +#include "xla/debug_options_flags.h" +#include "xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor_pimpl.h" + +namespace stream_executor { +namespace gpu { + +bool BlasLtGemmRunner::autotune_enabled_ = true; + +bool operator ==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs) { + return AsTuple(rhs) == AsTuple(lhs); +} + +bool operator ==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs) { + return AsTuple(rhs) == AsTuple(lhs); +} + +std::ostream& operator <<(std::ostream& os, const StridedGemmConfig& cfg) { + return os << "trans_a/b: " << (int)cfg.trans_a << "/" << (int)cfg.trans_b << + " m: " << cfg.m << " n: " << cfg.n << " k: " << cfg.k << + " batch_count: " << cfg.batch_count << + " lda: " << cfg.lda << " ldb: " << cfg.ldb << " ldc: " << cfg.ldc << + " stride_a: " << cfg.stride_a << " stride_b: " << cfg.stride_b << + " stride_c: " << cfg.stride_c << + " type_a: " << (int)cfg.type_a << " type_b: " << (int)cfg.type_b << + " type_c: " << (int)cfg.type_c << + " alpha: " << cfg.alpha << " beta: " << cfg.beta; +} + +BlasLtGemmRunner::BlasLtGemmRunner(StreamExecutor *parent) : + mutex_(std::make_unique< absl::Mutex >()), + autotune_config_(std::make_unique< xla::gpu::AutotuneConfig >( + xla::gpu::DeviceConfig{parent, nullptr}, + xla::GetDebugOptionsFromFlags())) + { } + +BlasLtGemmRunner::~BlasLtGemmRunner() { } + + +/*static*/ BlasLtGemmRunner& BlasLtGemmRunner::i(const Stream *stream) { + static absl::Mutex m(absl::kConstInit); + // Each GPU gets a different cache instance + static std::vector> meta(8); + absl::MutexLock lock(&m); + size_t dev_id = stream->parent()->device_ordinal(); + if (dev_id >= meta.size()) meta.resize(dev_id + 1); + auto& res = meta[dev_id]; + if (!res) { + autotune_enabled_ = xla::GetDebugOptionsFromFlags().xla_gpu_autotune_level() > 0; + res.reset(new BlasLtGemmRunner(stream->parent())); + xla::gpu::AutotunerUtil::LoadAutotuneResultsFromFileOnce(*res->autotune_config_); + } + return *res; +} + +template < class TuneFunc > +xla::StatusOr< gpu::BlasLt::MatmulAlgorithm 
> BlasLtGemmRunner::Autotune( + const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms, + TuneFunc&& benchmark_func) { + gpu::BlasLt::MatmulAlgorithm best_algo; + float best_ms = std::numeric_limits< float >::max(), total_ms = 0; + uint32_t n_warmups = 1, n_iters = 5, n_total = n_warmups + n_iters, i = 0; + + for (uint32_t j = 0; j < algorithms.size(); j++) { + const auto& algo = algorithms[j]; + if (!benchmark_func(algo, nullptr).ok()) continue; + + blas::ProfileResult profile; + for (i = 0, total_ms = 0; i < n_total; i++) { + auto res = benchmark_func(algo, &profile); + if (!res.ok() || !profile.is_valid()) { + VLOG(1) << j << ": gemm algorithm is not valid: " /* << res.error_message() */; + break; + } + if (i >= n_warmups) total_ms += profile.elapsed_time_in_ms(); + } + if (i < n_total) continue; // invalid algorithm + total_ms /= n_iters; + VLOG(2) << j << ": gemm algorithm " << profile.algorithm() << " took " + << total_ms << "ms, workspace: " << algo.workspace_size; + if (total_ms < best_ms) { + best_ms = total_ms, best_algo = algo; + } + } // for algorithms + if (!best_algo.opaque_algo.has_value()) { + return xla::InternalError("No valid gemm algorithms found!"); + } + return best_algo; +} + +xla::StatusOr< std::array< uint64_t, 3 >> BlasLtGemmRunner::ContiguousStrides( + const ArraySlice& a, + const ArraySlice& b, + const ArraySlice& c, int64 batch_count) { + + uint64_t bsa = 0, bsb = 0, bsc = 0; + using CT = const uint8_t; + for(int64 i = 0; i < batch_count-1; i++) { + uint64_t da = (CT *)a[i + 1]->opaque() - (CT *)a[i]->opaque(), + db = (CT *)b[i + 1]->opaque() - (CT *)b[i]->opaque(), + dc = (CT *)c[i + 1]->opaque() - (CT *)c[i]->opaque(); + if(i == 0) { + bsa = da, bsb = db, bsc = dc; + } else if(!(bsa == da && bsb == db && bsc == dc)) { // strides mismatch + return xla::InternalError("Strides are not consistent!"); + } + } + return std::array< uint64_t, 3 >{ bsa, bsb, bsc }; +} + +xla::Status BlasLtGemmRunner::RunBatchedImpl(Stream& stream, + blas::Transpose trans_a, blas::Transpose trans_b, int64 m, int64 n, int64 k, + const void *alpha, blas::DataType type_a, const void** a, int64 lda, + blas::DataType type_b, const void** b, int64 ldb, const void *beta, + blas::DataType type_c, void** c, int64 ldc, int64 batch_count, + ScratchAllocator* allocator) +{ + + TF_ASSIGN_OR_RETURN(auto compute_type, + gpu::GetBlasComputationType(type_a, type_c, 0)); + + GroupedGemmConfig cfg{ + .m = (int64)m, + .n = (int64)n, + .k = (int64)k, + .batch_count = (int64)batch_count, + .trans_a = trans_a, + .trans_b = trans_b, + .alpha = alpha, + .beta = beta, + .type_a = type_a, + .type_b = type_b, + .type_c = type_c, + .type_d = type_c, + .lda = (int64)lda, + .ldb = (int64)ldb, + .ldc = (int64)ldc, + .ldd = (int64)ldc, + .compute_type = compute_type, + .a = a, + .b = b, + .c = const_cast< const void **>(c), + .d = c, + }; + + absl::MutexLock lock(mutex_.get()); + + auto res = grouped_gemm_map_.find(cfg); + if (res == grouped_gemm_map_.end()) { + // NOTE: we assume that pointers a,b,c come from the device mem + // hence we need to block stream here + TF_ASSIGN_OR_RETURN(auto plan_res, + gpu::BlasLt::CreateGroupedMatmulPlan(&stream, cfg)); + res = grouped_gemm_map_.emplace(cfg, std::move(plan_res)).first; + + size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1; + // discard solutions with non-zero workspace if allocator is not given + TF_ASSIGN_OR_RETURN(auto algorithms, res->second->GetAlgorithms( + num_solutions, allocator == nullptr ? 
0 : 1ull << 32));
+
+    VLOG(1) << stream.parent() << ": new GGemm config: " <<
+          grouped_gemm_map_.size() << " #valid algorithms: " << algorithms.size();
+
+    BlasLt::MatmulAlgorithm best_algo;
+    if (!autotune_enabled_) {
+      if (algorithms.empty()) return xla::InternalError("No GG algorithms found!");
+      best_algo = algorithms[0]; // otherwise use default algorithm
+    } else {
+      // assign to the outer best_algo (a new local here would shadow it and
+      // discard the autotuning result)
+      TF_ASSIGN_OR_RETURN(best_algo, Autotune(algorithms,
+        [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){
+          if (profile == nullptr) {
+            return res->second->SetAlgorithm(algo, allocator);
+          }
+          return res->second->ExecuteOnStream(&stream, cfg, profile);
+      }));
+    }
+    TF_RETURN_IF_ERROR(res->second->SetAlgorithm(best_algo, allocator));
+  }
+  return res->second->ExecuteOnStream(&stream, cfg);
+}
+
+xla::Status BlasLtGemmRunner::RunStridedBatchedImpl(Stream& stream,
+    blas::Transpose trans_a, blas::Transpose trans_b, int64 m, int64 n, int64 k,
+    xla::complex128 alpha,
+    blas::DataType type_a, const DeviceMemoryBase& a, int64 lda, int64 stride_a,
+    blas::DataType type_b, const DeviceMemoryBase& b, int64 ldb, int64 stride_b,
+    double beta,
+    blas::DataType type_c, DeviceMemoryBase *c, int64 ldc, int64 stride_c,
+    int64 batch_count, ScratchAllocator* allocator)
+{
+  StridedGemmConfig scfg{
+    .m = m,
+    .n = n,
+    .k = k,
+    .batch_count = (int64)batch_count,
+    .trans_a = trans_a,
+    .trans_b = trans_b,
+    .alpha = alpha,
+    .beta = beta,
+    .type_a = type_a,
+    .type_b = type_b,
+    .type_c = type_c,
+    .lda = lda,
+    .ldb = ldb,
+    .ldc = ldc,
+    .stride_a = stride_a,
+    .stride_b = stride_b,
+    .stride_c = stride_c,
+  };
+
+  absl::MutexLock lock(mutex_.get());
+
+  auto res = strided_gemm_map_.find(scfg);
+  while (res == strided_gemm_map_.end()) {
+    int64 row_a = m, col_a = k, row_b = k, col_b = n;
+    if (trans_a == blas::Transpose::kTranspose) std::swap(row_a, col_a);
+    if (trans_b == blas::Transpose::kTranspose) std::swap(row_b, col_b);
+
+    auto order = MatrixLayout::Order::kColumnMajor;
+    GemmConfig cfg = {
+      .lhs_layout = MatrixLayout(type_a, row_a, col_a, order, batch_count,
+                                 lda, stride_a, trans_a),
+
+      .rhs_layout = MatrixLayout(type_b, row_b, col_b, order, batch_count,
+                                 ldb, stride_b, trans_b),
+
+      .c_layout = MatrixLayout(type_c, m, n, order, batch_count,
+                               ldc, stride_c),
+      .output_layout = MatrixLayout(type_c, m, n, order, batch_count,
+                                    ldc, stride_c),
+      .alpha = alpha,
+      .beta = beta,
+      .compute_precision = -1,
+      .epilogue = gpu::BlasLt::Epilogue::kDefault,
+    };
+
+    TF_ASSIGN_OR_RETURN(auto plan_res,
+          gpu::BlasLt::GetMatmulPlan(&stream, cfg));
+    res = strided_gemm_map_.emplace(scfg, std::move(plan_res)).first;
+
+    size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1;
+    // discard solutions with non-zero workspace if allocator is not given
+    TF_ASSIGN_OR_RETURN(auto algorithms, res->second->GetAlgorithms(
+          num_solutions, allocator == nullptr ?
0 : 1ull << 32)); + + VLOG(1) << &stream << " dev " << stream.parent() << '(' << + stream.parent()->device_ordinal() << "): new StridedBatched config: " + << strided_gemm_map_.size() << " #algorithms: " << algorithms.size(); + + if (!autotune_enabled_) { + if (algorithms.empty()) return xla::InternalError("No algorithms found!"); + res->second->SetAlgorithm(algorithms[0]); + break; + } + + BlasLt::MatmulAlgorithm best_algo{ .id = blas::kNoAlgorithm }; + xla::gpu::AutotuneCacheKey key(ToCSVString(cfg, /*full_string*/false)); + auto opt_res = xla::gpu::AutotunerUtil::TryToFindInInMemoryCache(key); + if (opt_res.has_value()) { + auto id = *opt_res; + for (const auto& algo : algorithms) { + if (algo.id == id) best_algo = algo; + } + if (best_algo.id == blas::kNoAlgorithm) { + LOG(WARNING) << "Best algorithm not valid: need to autotune.."; + } + } + + if (best_algo.id == blas::kNoAlgorithm) { + TF_ASSIGN_OR_RETURN(best_algo, Autotune(algorithms, + [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){ + if (profile == nullptr) { + return res->second->SetAlgorithm(algo); + } + return res->second->ExecuteOnStream( + &stream, a, b, *c, *c, + DeviceMemoryBase{}, // bias + DeviceMemoryBase{}, // aux + DeviceMemoryBase{}, // a_scale + DeviceMemoryBase{}, // b_scale + DeviceMemoryBase{}, // c_scale + DeviceMemoryBase{}, // d_scale + DeviceMemoryBase{}, // d_amax + absl::nullopt, // workspace + allocator, // allocator + profile); + })); + xla::gpu::AutotunerUtil::CacheValue ares = best_algo.id; + // reread algorithm ID from cache again (in case some other thread has + // already added this config to the cache to be sure we use the same ID) + auto new_id = xla::gpu::AutotunerUtil::AddResultToInMemoryCache(key, ares, + *autotune_config_); + + if (new_id != best_algo.id) { + for (const auto& algo : algorithms) { + if (algo.id == new_id) best_algo = algo; + } + } + } // best_algo.id == blas::kNoAlgorithm + + res->second->SetAlgorithm(best_algo); + break; + } // while + return res->second->ExecuteOnStream( + &stream, a, b, *c, *c, + DeviceMemoryBase{}, // bias + DeviceMemoryBase{}, // aux + DeviceMemoryBase{}, // a_scale + DeviceMemoryBase{}, // b_scale + DeviceMemoryBase{}, // c_scale + DeviceMemoryBase{}, // d_scale + DeviceMemoryBase{}, // d_amax + absl::nullopt, // workspace + allocator); // allocator +} + +} // namespace gpu + +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h new file mode 100644 index 00000000000000..4d317c5c7a1c93 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h @@ -0,0 +1,260 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ +#define XLA_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ + +#include "absl/container/flat_hash_map.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "xla/util.h" + +using tensorflow::gtl::ArraySlice; +typedef ::std::int64_t int64; + + +namespace xla { +namespace gpu { +class AutotuneConfig; +} +} + +namespace stream_executor { + +namespace gpu { + +struct StridedGemmConfig { + int64 m, n, k, batch_count; + blas::Transpose trans_a, trans_b; + xla::complex128 alpha; + double beta; + blas::DataType type_a, type_b, type_c; + int64 lda, ldb, ldc; + int64 stride_a, stride_b, stride_c; +}; + +namespace { + +auto AsTuple(const GroupedGemmConfig& p) { + // NOTE: alpha, beta and data pointers are not included in cache !! + return std::make_tuple(p.m, p.n, p.k, p.batch_count, + p.trans_a, p.trans_b, + p.type_a, p.type_b, p.type_c, p.type_d, + p.lda, p.ldb, p.ldc, p.ldd, + p.compute_type); +} + +auto AsTuple(const StridedGemmConfig& p) { + return std::make_tuple(p.m, p.n, p.k, p.batch_count, + p.trans_a, p.trans_b, p.alpha.real(), p.alpha.imag(), p.beta, + p.type_a, p.type_b, p.type_c, + p.lda, p.ldb, p.ldc, + p.stride_a, p.stride_b, p.stride_c); +} + +} // namespace + +bool operator ==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs); +bool operator ==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs); + +template +H AbslHashValue(H h, const GroupedGemmConfig& params) { + return H::combine(std::move(h), AsTuple(params)); +} + +template +H AbslHashValue(H h, const StridedGemmConfig& params) { + return H::combine(std::move(h), AsTuple(params)); +} + +struct BlasLtGemmRunner { + + static BlasLtGemmRunner& i(const Stream *stream); + + template < class Scalar > + xla::complex128 Convert(Scalar x) { + if constexpr(std::is_same::value || + std::is_same::value) { + return static_cast< xla::complex128 >(x); + } else { + return static_cast< double >(x); + } + } + + template < class Scalar, class TypeA, class TypeB, class TypeC > + xla::Status Run(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + Scalar alpha, const DeviceMemory& a, int64 lda, + const DeviceMemory& b, int64 ldb, + Scalar beta, DeviceMemory *c, int64 ldc, + ScratchAllocator* allocator) { + + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k, + Convert(alpha), + type_a, a, lda, 0, type_b, b, ldb, 0, + Convert(beta).real(), // only real betas are supported!! 
+        type_c, c, ldc, 0, 1, allocator);
+  }
+
+  template < class Scalar, class TypeA, class TypeB, class TypeC >
+  xla::Status Run(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k,
+      Scalar alpha, const TypeA* a, int64 lda,
+      const TypeB *b, int64 ldb,
+      Scalar beta, TypeC *c, int64 ldc,
+      ScratchAllocator* allocator) {
+
+    auto type_a = dnn::ToDataType<TypeA>::value,
+         type_b = dnn::ToDataType<TypeB>::value,
+         type_c = dnn::ToDataType<TypeC>::value;
+
+    DeviceMemoryBase mem_c{c};
+    return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k,
+        Convert(alpha),
+        type_a, DeviceMemoryBase{const_cast< TypeA *>(a)}, lda, 0,
+        type_b, DeviceMemoryBase{const_cast< TypeB *>(b)}, ldb, 0,
+        Convert(beta).real(), // only real betas are supported!!
+        type_c, &mem_c, ldc, 0, 1, allocator);
+  }
+
+
+  template < class Scalar, class TypeA, class TypeB, class TypeC>
+  xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k,
+      Scalar alpha, const TypeA* a, int64 lda, int64 stride_a,
+      const TypeB* b, int64 ldb, int64 stride_b,
+      Scalar beta, TypeC* c, int64 ldc, int64 stride_c,
+      int64 batch_count, ScratchAllocator* allocator) {
+
+    auto type_a = dnn::ToDataType<TypeA>::value,
+      type_b = dnn::ToDataType<TypeB>::value,
+      type_c = dnn::ToDataType<TypeC>::value;
+    DeviceMemoryBase mem_c{c};
+    return RunStridedBatchedImpl(
+        stream, trans_a, trans_b, m, n, k, Convert(alpha), type_a,
+        DeviceMemoryBase{const_cast<TypeA*>(a)}, lda, stride_a, type_b,
+        DeviceMemoryBase{const_cast<TypeB*>(b)}, ldb, stride_b,
+        Convert(beta).real(),  // only real betas are supported!!
+        type_c, &mem_c, ldc, stride_c, batch_count, allocator);
+  }
+
+  template < class Scalar, class TypeA, class TypeB, class TypeC>
+  xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k,
+      Scalar alpha, const DeviceMemory<TypeA>& a, int64 lda, int64 stride_a,
+      const DeviceMemory<TypeB>& b, int64 ldb, int64 stride_b,
+      Scalar beta, DeviceMemory<TypeC> *c, int64 ldc, int64 stride_c,
+      int64 batch_count, ScratchAllocator* allocator) {
+
+    auto type_a = dnn::ToDataType<TypeA>::value,
+         type_b = dnn::ToDataType<TypeB>::value,
+         type_c = dnn::ToDataType<TypeC>::value;
+    return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k,
+        Convert(alpha),
+        type_a, a, lda, stride_a, type_b, b, ldb, stride_b,
+        Convert(beta).real(), // only real betas are supported!!
+        type_c, c, ldc, stride_c, batch_count, allocator);
+  }
+
+  template < class Scalar, class T >
+  xla::Status RunBatched(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k, Scalar alpha,
+      const ArraySlice<DeviceMemory<T> *> &a, int64 lda,
+      const ArraySlice<DeviceMemory<T> *> &b, int64 ldb, Scalar beta,
+      const ArraySlice<DeviceMemory<T> *> &c, int64 ldc,
+      int64 batch_count, ScratchAllocator* allocator) {
+
+    // NOTE: Scalar types shall be verified for correctness vs T!!
+    auto type = dnn::ToDataType<T>::value;
+    auto cvt = [](auto x){
+      using TT = ArraySlice<const DeviceMemoryBase *>;
+      auto ptr = reinterpret_cast<const void *>(&x);
+      return *reinterpret_cast<const TT *>(ptr);
+    };
+
+    auto res = ContiguousStrides(cvt(a), cvt(b), cvt(c), batch_count);
+    if (res.ok()) {
+      auto strides = std::move(res.value());
+      return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k,
+          Convert(alpha),
+          type, *a[0], lda, strides[0] / sizeof(T),
+          type, *b[0], ldb, strides[1] / sizeof(T),
+          Convert(beta).real(), // only real betas are supported!!
+ type, c[0], ldc, strides[2] / sizeof(T), batch_count, allocator); + } + return xla::InternalError("RunBatched: port::ArraySlice NYI!"); + } + + + template < class Scalar, class T > + xla::Status RunBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, uint64 m, uint64 n, uint64 k, + Scalar alpha, const T** a, int lda, + const T** b, int ldb, Scalar beta, + T** c, int64 ldc, int64 batch_count, ScratchAllocator* allocator){ + + auto type = dnn::ToDataType::value; + return RunBatchedImpl(stream, trans_a, trans_b, m, n, k, + &alpha, type, reinterpret_cast< const void **>(a), lda, + type, reinterpret_cast< const void **>(b), ldb, &beta, + type, reinterpret_cast< void **>(c), ldc, batch_count, allocator); + } + + ~BlasLtGemmRunner(); + BlasLtGemmRunner& operator=(BlasLtGemmRunner&& rhs) noexcept = default; + BlasLtGemmRunner(BlasLtGemmRunner&& rhs) noexcept = default; + +private: + explicit BlasLtGemmRunner(StreamExecutor *parent); + + template < class TuneFunc > + xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > Autotune( + const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms, + TuneFunc&& benchmark_func); + + + xla::Status RunBatchedImpl(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + const void *alpha, blas::DataType type_a, const void** a, int64 lda, + blas::DataType type_b, const void** b, int64 ldb, const void *beta, + blas::DataType type_c, void** c, int64 ldc, int64 batch_count, + ScratchAllocator* allocator); + + xla::Status RunStridedBatchedImpl(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, xla::complex128 alpha, + blas::DataType type_a, const DeviceMemoryBase& a, int64 lda, int64 stride_a, + blas::DataType type_b, const DeviceMemoryBase& b, int64 ldb, int64 stride_b, + double beta, + blas::DataType type_c, DeviceMemoryBase *c, int64 ldc, int64 stride_c, + int64 batch_count, ScratchAllocator* allocator); + + xla::StatusOr< std::array< uint64_t, 3 >> ContiguousStrides( + const ArraySlice& a, + const ArraySlice& b, + const ArraySlice& c, int64 batch_count); + + static bool autotune_enabled_; + std::unique_ptr< absl::Mutex > mutex_; + std::unique_ptr< xla::gpu::AutotuneConfig > autotune_config_; + absl::flat_hash_map grouped_gemm_map_; + absl::flat_hash_map strided_gemm_map_; +}; + +} // namespace gpu + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 483d6286f26021..e442697d432270 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -290,7 +290,7 @@ class GpuDriver { // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#execution-control static tsl::Status LaunchKernel( - GpuContext* context, absl::string_view kernel_name, + GpuContext* context, GpuFunctionHandle function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h index b7fa33c87e5ace..eaa438fd5db882 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h +++ 
b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h @@ -37,7 +37,7 @@ class GpuKernel : public internal::KernelInterface { public: GpuKernel() : gpu_function_(nullptr), - arity_(0), + arity_(0), inprocess_(false), preferred_cache_config_(KernelCacheConfig::kNoPreference) {} // Note that the function is unloaded when the module is unloaded, and the @@ -55,6 +55,9 @@ class GpuKernel : public internal::KernelInterface { return const_cast(gpu_function_); } + void SetInProcessSymbol(bool inprocess) { inprocess_ = inprocess; } + bool IsInProcessSymbol() const { return inprocess_; } + // Returns the slot that the GpuFunctionHandle is stored within for this // object, for the CUDA API which wants to load into a GpuFunctionHandle*. GpuFunctionHandle* gpu_function_ptr() { return &gpu_function_; } @@ -82,6 +85,7 @@ class GpuKernel : public internal::KernelInterface { private: GpuFunctionHandle gpu_function_; // Wrapped CUDA kernel handle. unsigned arity_; // Number of formal parameters the kernel takes. + bool inprocess_; // Preferred (but not required) cache configuration for this kernel. KernelCacheConfig preferred_cache_config_; diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc index c6415812a52ef2..d64b4f47ccc1fb 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc @@ -13,27 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" +#include #include #include +#include #include +#include +#include -#include "absl/base/call_once.h" #include "absl/container/fixed_array.h" -#include "absl/status/status.h" #include "absl/strings/str_format.h" -#include "absl/types/optional.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/asm_compiler.h" -#include "xla/stream_executor/gpu/gpu_asm_opts.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/framework/allocator.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel.h" +#include "tensorflow/compiler/xla/stream_executor/kernel.h" +#include "tensorflow/compiler/xla/stream_executor/launch_dim.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/lib/math/math_util.h" + +#include "tensorflow/core/util/env_var.h" namespace stream_executor { @@ -41,7 +44,7 @@ namespace stream_executor { // then multiplying by the divisor. 
For example: RoundUpToNearest(13, 8) => 16 template static T RoundUpToNearest(T value, T divisor) { - return tsl::MathUtil::CeilOfRatio(value, divisor) * divisor; + return tensorflow::MathUtil::CeilOfRatio(value, divisor) * divisor; } // The size of the redzone at the end of the user buffer is rounded up to a @@ -50,17 +53,18 @@ constexpr int64_t kRhsRedzoneAlign = 4; using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; -RedzoneAllocator::RedzoneAllocator(Stream* stream, - DeviceMemoryAllocator* memory_allocator, - GpuAsmOpts gpu_compilation_opts, - int64_t memory_limit, int64_t redzone_size, - uint8_t redzone_pattern) +RedzoneAllocator::RedzoneAllocator( + Stream* stream, + DeviceMemoryAllocator* memory_allocator, + const GpuAsmOpts& gpu_compilation_opts, + int64_t memory_limit, int64_t redzone_size, + uint8_t redzone_pattern) : device_ordinal_(stream->parent()->device_ordinal()), stream_(stream), memory_limit_(memory_limit), redzone_size_(RoundUpToNearest( redzone_size, - static_cast(tsl::Allocator::kAllocatorAlignment))), + static_cast(tensorflow::Allocator::kAllocatorAlignment))), redzone_pattern_(redzone_pattern), memory_allocator_(memory_allocator), gpu_compilation_opts_(gpu_compilation_opts) {} @@ -106,7 +110,7 @@ tsl::StatusOr> RedzoneAllocator::AllocateBytes( redzone_size_); uint8_t pattern_arr[] = {redzone_pattern_, redzone_pattern_, redzone_pattern_, - redzone_pattern_}; + redzone_pattern_}; uint32_t pattern32; std::memcpy(&pattern32, pattern_arr, sizeof(pattern32)); stream_->ThenMemset32(&lhs_redzone, pattern32, redzone_size_); @@ -119,66 +123,6 @@ tsl::StatusOr> RedzoneAllocator::AllocateBytes( return data_chunk; } -// PTX blob for the function which checks that every byte in -// input_buffer (length is buffer_length) is equal to redzone_pattern. -// -// On mismatch, increment the counter pointed to by out_mismatch_cnt_ptr. -// -// Generated from: -// __global__ void redzone_checker(unsigned char* input_buffer, -// unsigned char redzone_pattern, -// unsigned long long buffer_length, -// int* out_mismatched_ptr) { -// unsigned long long idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); -// } -// -// Code must compile for the oldest GPU XLA may be compiled for. -static const char* redzone_checker_ptx = R"( -.version 4.2 -.target sm_30 -.address_size 64 - -.visible .entry redzone_checker( - .param .u64 input_buffer, - .param .u8 redzone_pattern, - .param .u64 buffer_length, - .param .u64 out_mismatch_cnt_ptr -) -{ - .reg .pred %p<3>; - .reg .b16 %rs<3>; - .reg .b32 %r<6>; - .reg .b64 %rd<8>; - - ld.param.u64 %rd6, [buffer_length]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.u64.u32 %rd3, %r4; - setp.ge.u64 %p1, %rd3, %rd6; - @%p1 bra LBB6_3; - ld.param.u8 %rs1, [redzone_pattern]; - ld.param.u64 %rd4, [input_buffer]; - cvta.to.global.u64 %rd2, %rd4; - add.s64 %rd7, %rd2, %rd3; - ld.global.u8 %rs2, [%rd7]; - setp.eq.s16 %p2, %rs2, %rs1; - @%p2 bra LBB6_3; - ld.param.u64 %rd5, [out_mismatch_cnt_ptr]; - cvta.to.global.u64 %rd1, %rd5; - atom.global.add.u32 %r5, [%rd1], 1; -LBB6_3: - ret; -} -)"; - -// The PTX in redzone_checker_ptx has to be launched with specified types -// in the specified order. -using ComparisonKernelT = TypedKernel, uint8_t, uint64_t, - DeviceMemory>; // Check that redzones weren't overwritten on a host. 
// @@ -188,8 +132,9 @@ static tsl::StatusOr CheckRedzoneHost( absl::string_view name, Stream* stream, uint8_t redzone_pattern) { uint64_t size = redzone.size(); auto redzone_data = std::make_unique(size); - TF_RETURN_IF_ERROR(stream->ThenMemcpy(redzone_data.get(), redzone, size) - .BlockHostUntilDone()); + stream->ThenMemcpy(redzone_data.get(), redzone, size); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + std::array pattern_arr; pattern_arr.fill(redzone_pattern); @@ -217,10 +162,11 @@ static tsl::StatusOr CheckRedzoneHost( // Run the redzone checker on the provided buffer redzone. // // Increment out_param if mismatch occurs. -static tsl::Status RunRedzoneChecker( - Stream* stream, const DeviceMemory& redzone, - uint8_t redzone_pattern, const DeviceMemory& out_param, - const ComparisonKernelT& comparison_kernel) { +static tsl::Status RunRedzoneChecker(Stream* stream, + const DeviceMemory& redzone, + uint8_t redzone_pattern, + const DeviceMemory& out_param, + const ComparisonKernel& comparison_kernel) { StreamExecutor* executor = stream->parent(); if (redzone.size() == 0) { @@ -231,25 +177,27 @@ static tsl::Status RunRedzoneChecker( int64_t threads_per_block = std::min( executor->GetDeviceDescription().threads_per_block_limit(), num_elements); int64_t block_count = - tsl::MathUtil::CeilOfRatio(num_elements, threads_per_block); + tensorflow::MathUtil::CeilOfRatio(num_elements, threads_per_block); TF_RETURN_IF_ERROR(stream->ThenLaunch( - ThreadDim(threads_per_block), BlockDim(block_count), comparison_kernel, - redzone, redzone_pattern, redzone.size(), out_param)); - return ::tsl::OkStatus(); + ThreadDim(threads_per_block), BlockDim(block_count), + comparison_kernel, redzone, redzone_pattern, + redzone.size(), out_param)); + return tsl::OkStatus(); } // Since we reuse the same buffer for multiple checks, we re-initialize redzone // with a NaN pattern after a failed check. // // This function is blocking, since redzone failing is a rare event. -static tsl::Status ReinitializeRedzone(Stream* stream, DeviceMemoryBase redzone, - uint8_t redzone_pattern) { +static tsl::Status ReinitializeRedzone(Stream* stream, + DeviceMemoryBase redzone, + uint8_t redzone_pattern) { absl::FixedArray redzone_array(redzone.size()); redzone_array.fill(redzone_pattern); stream->ThenMemcpy(&redzone, redzone_array.data(), redzone.size()); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - return ::tsl::OkStatus(); + return tsl::OkStatus(); } // Check redzones around the user allocation. 
@@ -258,7 +206,7 @@ static tsl::Status ReinitializeRedzone(Stream* stream, DeviceMemoryBase redzone, static tsl::StatusOr CheckRedzonesForBuffer( Stream* stream, DeviceMemoryBase memory, const DeviceMemory& out_param, - const ComparisonKernelT& comparison_kernel, int64_t user_allocation_size, + const ComparisonKernel& comparison_kernel, int64_t user_allocation_size, uint64_t redzone_size, uint8_t redzone_pattern) { StreamExecutor* executor = stream->parent(); int64_t rhs_slop = @@ -277,10 +225,10 @@ static tsl::StatusOr CheckRedzonesForBuffer( executor->GetSubBuffer(&buffer_uint8, redzone_size + user_allocation_size, /*element_count=*/redzone_size + rhs_slop); - TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, lhs_redzone, redzone_pattern, - out_param, comparison_kernel)); - TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, rhs_redzone, redzone_pattern, - out_param, comparison_kernel)); + TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, lhs_redzone, redzone_pattern, out_param, + comparison_kernel)); + TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, rhs_redzone, redzone_pattern, out_param, + comparison_kernel)); int64_t result; CHECK_EQ(out_param.size(), sizeof(result)); stream->ThenMemcpy(&result, out_param, sizeof(result)); @@ -308,47 +256,36 @@ static tsl::StatusOr CheckRedzonesForBuffer( } tsl::StatusOr RedzoneAllocator::CheckRedzones() const { + // add for PPU + static bool found_ppu_device = [] { + bool found_ppu = false; + if (!tensorflow::ReadBoolFromEnvVar("FOUND_PPU_DEVICE", false, &found_ppu).ok()) { + return false; + } + if (found_ppu) { + return true; + } + return false; + }(); + if (found_ppu_device) { + return RedzoneCheckStatus::OK(); + } StreamExecutor* executor = stream_->parent(); - absl::Span compiled_ptx = {}; - tsl::StatusOr> compiled_ptx_or = - CompileGpuAsmOrGetCached(executor->device_ordinal(), redzone_checker_ptx, - gpu_compilation_opts_); - if (compiled_ptx_or.ok()) { - compiled_ptx = compiled_ptx_or.value(); - } else { - static absl::once_flag ptxas_not_found_logged; - absl::call_once(ptxas_not_found_logged, [&]() { - LOG(WARNING) << compiled_ptx_or.status() - << "\nRelying on driver to perform ptx compilation. " - << "\nModify $PATH to customize ptxas location." 
- << "\nThis message will be only logged once."; - }); - } + TF_ASSIGN_OR_RETURN( + const ComparisonKernel* kernel, + GetComparisonKernel(stream_->parent(), gpu_compilation_opts_)); ScopedDeviceMemory out_param = executor->AllocateOwnedScalar(); stream_->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); -#if GOOGLE_CUDA - TF_ASSIGN_OR_RETURN( - std::shared_ptr loaded_kernel, - (LoadKernelOrGetPtr, uint8_t, uint64_t, - DeviceMemory>( - executor, "redzone_checker", redzone_checker_ptx, compiled_ptx))); -#else - TF_ASSIGN_OR_RETURN( - std::unique_ptr loaded_kernel, - (executor->CreateTypedKernel, uint8, uint64_t, - DeviceMemory>( - "redzone_checker", redzone_checker_ptx, compiled_ptx))); -#endif // GOOGLE_CUDA for (const auto& buf_and_size : allocated_buffers_) { TF_ASSIGN_OR_RETURN( RedzoneCheckStatus redzone_status, CheckRedzonesForBuffer(stream_, *buf_and_size.first, out_param.cref(), - *loaded_kernel, buf_and_size.second, + *kernel, buf_and_size.second, redzone_size_, redzone_pattern_)); if (!redzone_status.ok()) { return redzone_status; @@ -365,4 +302,4 @@ std::string RedzoneCheckStatus::RedzoneFailureMsg() const { buffer_name, user_buffer_address, offset, expected_value, actual_value); } -} // namespace stream_executor +} // namespace stream_executor \ No newline at end of file diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h index 43694d3295c386..e2fcd9795d5794 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h @@ -13,18 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ -#define XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ +#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ #include +#include +#include #include -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/gpu/asm_compiler.h" -#include "xla/stream_executor/gpu/gpu_asm_opts.h" -#include "xla/stream_executor/scratch_allocator.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/math/math_util.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" + namespace stream_executor { @@ -43,12 +46,12 @@ class RedzoneAllocator : public ScratchAllocator { public: static constexpr int64_t kDefaultRedzoneSize = 1LL << 23; // 8MiB per side, 16MiB total. - static constexpr uint8_t kDefaultRedzonePattern = -1; // NOLINT + static constexpr uint8 kDefaultRedzonePattern = -1; RedzoneAllocator(Stream* stream, DeviceMemoryAllocator* memory_allocator, - GpuAsmOpts gpu_compilation_opts_, - int64_t memory_limit = (1LL << 32), // 4GB + const GpuAsmOpts& gpu_compilation_opts_, + int64_t memory_limit = (1LL << 32), int64_t redzone_size = kDefaultRedzoneSize, - uint8_t redzone_pattern = kDefaultRedzonePattern); + uint8 redzone_pattern = kDefaultRedzonePattern); // Redzones don't count towards the memory limit. 
int64_t GetMemoryLimitInBytes() override { return memory_limit_; } @@ -65,8 +68,7 @@ class RedzoneAllocator : public ScratchAllocator { RedzoneCheckStatus() = default; RedzoneCheckStatus(absl::string_view buffer_name, void* user_buffer_address, - int64_t offset, uint64_t expected_value, - uint64_t actual_value) + int64_t offset, uint64 expected_value, uint64 actual_value) : buffer_name(buffer_name), user_buffer_address(user_buffer_address), offset(offset), @@ -82,8 +84,8 @@ class RedzoneAllocator : public ScratchAllocator { std::string buffer_name = {}; void* user_buffer_address = nullptr; int64_t offset = 0; - uint64_t expected_value = 0; - uint64_t actual_value = 0; + uint64 expected_value = 0; + uint64 actual_value = 0; }; // Determines whether redzones around all allocated buffers are unmodified. @@ -98,7 +100,6 @@ class RedzoneAllocator : public ScratchAllocator { // redzone has been detected. // - A stream error, if loading or launching the kernel has failed. tsl::StatusOr CheckRedzones() const; - Stream* stream() const { return stream_; } private: @@ -114,7 +115,7 @@ class RedzoneAllocator : public ScratchAllocator { // returned to users will be misaligned. const int64_t redzone_size_; - const uint8_t redzone_pattern_; + const uint8 redzone_pattern_; DeviceMemoryAllocator* memory_allocator_; GpuAsmOpts gpu_compilation_opts_; @@ -132,4 +133,4 @@ class RedzoneAllocator : public ScratchAllocator { } // namespace stream_executor -#endif // XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ +#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ \ No newline at end of file diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel.h b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel.h new file mode 100644 index 00000000000000..f7a1e59b3b6476 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel.h @@ -0,0 +1,39 @@ +/* Copyright 2024 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_KERNEL_H_ + +#include + +#include "tensorflow/tsl/platform/statusor.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/stream_executor.h" + +namespace stream_executor { +using ComparisonKernel = TypedKernel, uint8_t, uint64_t, + DeviceMemory>; + +// Returns a GPU kernel that checks a memory location for redzone patterns. +// Parameters are (buffer_address, redzone_pattern, buffer_length, +// mismatch_count_ptr). For each byte in buffer `[buffer_address : +// buffer_address +// + buffer_length]` that is not equal to `redzone_pattern`, +// `*mismatch_count_ptr` gets incremented by 1. 
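+//
+// Rough usage sketch (illustrative only; RunRedzoneChecker in
+// redzone_allocator.cc is the actual caller, and it derives the launch
+// dimensions from the device description rather than the literals used here):
+//
+//   TF_ASSIGN_OR_RETURN(const ComparisonKernel* kernel,
+//                       GetComparisonKernel(executor, gpu_asm_opts));
+//   int64_t threads_per_block = std::min<int64_t>(1024, buffer.size());
+//   int64_t block_count =
+//       (buffer.size() + threads_per_block - 1) / threads_per_block;
+//   TF_RETURN_IF_ERROR(stream->ThenLaunch(
+//       ThreadDim(threads_per_block), BlockDim(block_count), *kernel,
+//       buffer, redzone_pattern, buffer.size(), mismatch_count));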
+tsl::StatusOr GetComparisonKernel( + StreamExecutor* executor, GpuAsmOpts gpu_asm_opts); + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc new file mode 100644 index 00000000000000..cec8a09aaee6e8 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc @@ -0,0 +1,146 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/const_init.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "xla/stream_executor/lib/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_asm_compiler.h" +#include "xla/stream_executor/cuda/cuda_driver.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/redzone_allocator_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/typed_kernel_factory.h" +#include "tsl/platform/statusor.h" + +namespace stream_executor { +// Maintains a cache of pointers to loaded kernels +template +static StatusOr*> LoadKernelOrGetPtr( + StreamExecutor* executor, absl::string_view kernel_name, + absl::string_view ptx, absl::Span cubin_data) { + using KernelPtrCacheKey = + std::tuple; + + static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit); + static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) = + *new absl::node_hash_map>(); + CUcontext current_context = cuda::CurrentContextOrDie(); + KernelPtrCacheKey kernel_ptr_cache_key{current_context, kernel_name, ptx}; + absl::MutexLock lock(&kernel_ptr_cache_mutex); + + auto it = kernel_ptr_cache.find(kernel_ptr_cache_key); + if (it == kernel_ptr_cache.end()) { + TF_ASSIGN_OR_RETURN(TypedKernel loaded, + (TypedKernelFactory::Create( + executor, kernel_name, ptx, cubin_data))); + it = + kernel_ptr_cache.emplace(kernel_ptr_cache_key, std::move(loaded)).first; + } + + CHECK(it != kernel_ptr_cache.end()); + return &it->second; +} + +// PTX blob for the function which checks that every byte in +// input_buffer (length is buffer_length) is equal to redzone_pattern. +// +// On mismatch, increment the counter pointed to by out_mismatch_cnt_ptr. 
+// +// Generated from: +// __global__ void redzone_checker(unsigned char* input_buffer, +// unsigned char redzone_pattern, +// unsigned long long buffer_length, +// int* out_mismatched_ptr) { +// unsigned long long idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) return; +// if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); +// } +// +// Code must compile for the oldest GPU XLA may be compiled for. +static const char* redzone_checker_ptx = R"( +.version 4.2 +.target sm_30 +.address_size 64 + +.visible .entry redzone_checker( + .param .u64 input_buffer, + .param .u8 redzone_pattern, + .param .u64 buffer_length, + .param .u64 out_mismatch_cnt_ptr +) +{ + .reg .pred %p<3>; + .reg .b16 %rs<3>; + .reg .b32 %r<6>; + .reg .b64 %rd<8>; + + ld.param.u64 %rd6, [buffer_length]; + mov.u32 %r1, %tid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %ntid.x; + mad.lo.s32 %r4, %r3, %r2, %r1; + cvt.u64.u32 %rd3, %r4; + setp.ge.u64 %p1, %rd3, %rd6; + @%p1 bra LBB6_3; + ld.param.u8 %rs1, [redzone_pattern]; + ld.param.u64 %rd4, [input_buffer]; + cvta.to.global.u64 %rd2, %rd4; + add.s64 %rd7, %rd2, %rd3; + ld.global.u8 %rs2, [%rd7]; + setp.eq.s16 %p2, %rs2, %rs1; + @%p2 bra LBB6_3; + ld.param.u64 %rd5, [out_mismatch_cnt_ptr]; + cvta.to.global.u64 %rd1, %rd5; + atom.global.add.u32 %r5, [%rd1], 1; +LBB6_3: + ret; +} +)"; + +StatusOr GetComparisonKernel( + StreamExecutor* executor, GpuAsmOpts gpu_asm_opts) { + absl::Span compiled_ptx = {}; + StatusOr> compiled_ptx_or = + CompileGpuAsmOrGetCached(executor->device_ordinal(), redzone_checker_ptx, + gpu_asm_opts); + if (compiled_ptx_or.ok()) { + compiled_ptx = compiled_ptx_or.value(); + } else { + static absl::once_flag ptxas_not_found_logged; + absl::call_once(ptxas_not_found_logged, [&]() { + LOG(WARNING) << compiled_ptx_or.status() + << "\nRelying on driver to perform ptx compilation. " + << "\nModify $PATH to customize ptxas location." + << "\nThis message will be only logged once."; + }); + } + + return LoadKernelOrGetPtr, uint8_t, uint64_t, + DeviceMemory>( + executor, "redzone_checker", redzone_checker_ptx, compiled_ptx); +} +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc new file mode 100644 index 00000000000000..bb916a46c63652 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/redzone_allocator_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/stream_executor.h" + +namespace { +__global__ void redzone_checker_kernel(uint8_t* input_buffer, + uint8_t redzone_pattern, + uint64_t buffer_length, + uint32_t* out_mismatched_ptr) { + uint64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); +} +} // namespace + +namespace stream_executor { + +tsl::StatusOr GetComparisonKernel( + StreamExecutor* executor, GpuAsmOpts /*gpu_asm_opts*/) { + + static auto kernel = + executor->CreateTypedKernel, uint8_t, uint64_t, + DeviceMemory>( + "redzone_checker", reinterpret_cast( + redzone_checker_kernel)); + + if (!kernel.ok()) return kernel.status(); + return kernel.value().get(); +} + +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc new file mode 100644 index 00000000000000..4068a16a0e65bd --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc @@ -0,0 +1,154 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include "xla/stream_executor/gpu/redzone_allocator.h" + +#include +#include + +#include "xla/stream_executor/lib/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/gpu_asm_opts.h" +#include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor { +namespace gpu { +namespace { + +using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; + +static void EXPECT_REDZONE_OK(StatusOr status) { + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(status.value().ok()); +} + +static void EXPECT_REDZONE_VIOLATION( + StatusOr status) { + EXPECT_TRUE(status.ok()); + EXPECT_FALSE(status.value().ok()); +} + +TEST(RedzoneAllocatorTest, WriteToRedzone) { + constexpr int64_t kRedzoneSize = 1 << 23; // 8MiB redzone on each side + // Redzone pattern should not be equal to zero; otherwise modify_redzone will + // break. 
+ constexpr uint8_t kRedzonePattern = 0x7e; + + // Allocate 32MiB + 1 byte (to make things misaligned) + constexpr int64_t kAllocSize = (1 << 25) + 1; + + Platform* platform = + PlatformManager::PlatformWithName(tensorflow::GpuPlatformName()).value(); + StreamExecutor* stream_exec = platform->ExecutorForDevice(0).value(); + GpuAsmOpts opts; + StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_exec->CreateStream()); + RedzoneAllocator allocator(stream.get(), &se_allocator, opts, + /*memory_limit=*/(1LL << 32), + /*redzone_size=*/kRedzoneSize, + /*redzone_pattern=*/kRedzonePattern); + TF_ASSERT_OK_AND_ASSIGN(DeviceMemory buf, + allocator.AllocateBytes(/*byte_size=*/kAllocSize)); + + EXPECT_REDZONE_OK(allocator.CheckRedzones()); + + char* buf_addr = reinterpret_cast(buf.opaque()); + DeviceMemoryBase lhs_redzone(buf_addr - kRedzoneSize, kRedzoneSize); + DeviceMemoryBase rhs_redzone(buf_addr + kAllocSize, kRedzoneSize); + + // Check that the redzones are in fact filled with kRedzonePattern. + auto check_redzone = [&](DeviceMemoryBase redzone, absl::string_view name) { + std::vector host_buf(kRedzoneSize); + TF_ASSERT_OK(stream->ThenMemcpy(host_buf.data(), redzone, kRedzoneSize)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + const int64_t kMaxMismatches = 16; + int64_t mismatches = 0; + for (int64_t i = 0; i < host_buf.size(); ++i) { + if (mismatches == kMaxMismatches) { + ADD_FAILURE() << "Hit max number of mismatches; skipping others."; + break; + } + if (host_buf[i] != kRedzonePattern) { + ++mismatches; + EXPECT_EQ(host_buf[i], kRedzonePattern) + << "at index " << i << " of " << name << " redzone"; + } + } + }; + + check_redzone(lhs_redzone, "lhs"); + check_redzone(rhs_redzone, "rhs"); + + // Modifies a redzone, checks that RedzonesAreUnmodified returns false, then + // reverts it back to its original value and checks that RedzonesAreUnmodified + // returns true. + auto modify_redzone = [&](DeviceMemoryBase redzone, int64_t offset, + absl::string_view name) { + SCOPED_TRACE(absl::StrCat(name, ", offset=", offset)); + DeviceMemoryBase redzone_at_offset( + reinterpret_cast(redzone.opaque()) + offset, 1); + char old_redzone_value = 0; + { EXPECT_REDZONE_OK(allocator.CheckRedzones()); } + TF_ASSERT_OK(stream->ThenMemcpy(&old_redzone_value, redzone_at_offset, 1)); + TF_ASSERT_OK(stream->MemZero(&redzone_at_offset, 1)); + EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones()); + + // Checking reinitializes the redzone. + EXPECT_REDZONE_OK(allocator.CheckRedzones()); + }; + + modify_redzone(lhs_redzone, /*offset=*/0, "lhs"); + modify_redzone(lhs_redzone, /*offset=*/kRedzoneSize - 1, "lhs"); + modify_redzone(rhs_redzone, /*offset=*/0, "rhs"); + modify_redzone(rhs_redzone, /*offset=*/kRedzoneSize - 1, "rhs"); +} + +// Older CUDA compute capabilities (<= 2.0) have a limitation that grid +// dimension X cannot be larger than 65535. +// +// Make sure we can launch kernels on sizes larger than that, given that the +// maximum number of threads per block is 1024. +TEST(RedzoneAllocatorTest, VeryLargeRedzone) { + // Make sure the redzone size would require grid dimension > 65535. 
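+  // Arithmetic: 65535 * 1024 + 1 bytes at <= 1024 threads per block needs
+  // ceil((65535 * 1024 + 1) / 1024) = 65536 blocks, one more than the 65535
+  // grid-dimension limit being exercised here.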
+  constexpr int64_t kRedzoneSize = 65535 * 1024 + 1;
+  Platform* platform =
+      PlatformManager::PlatformWithName(tensorflow::GpuPlatformName()).value();
+
+  StreamExecutor* stream_exec = platform->ExecutorForDevice(0).value();
+  GpuAsmOpts opts;
+  StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec});
+  TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_exec->CreateStream());
+  RedzoneAllocator allocator(
+      stream.get(), &se_allocator, opts,
+      /*memory_limit=*/ (1LL << 32),
+      /*redzone_size=*/kRedzoneSize,
+      /*redzone_pattern=*/-1);
+  (void)allocator.AllocateBytes(/*byte_size=*/1);
+  EXPECT_REDZONE_OK(allocator.CheckRedzones());
+}
+
+}  // namespace gpu
+}  // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/kernel_spec.cc b/third_party/xla/xla/stream_executor/kernel_spec.cc
index f0fabf44d5399e..92e86dbd9bac22 100644
--- a/third_party/xla/xla/stream_executor/kernel_spec.cc
+++ b/third_party/xla/xla/stream_executor/kernel_spec.cc
@@ -31,6 +31,10 @@ namespace stream_executor {
 KernelLoaderSpec::KernelLoaderSpec(absl::string_view kernel_name)
     : kernel_name_(std::string(kernel_name)) {}
 
+InProcessSymbol::InProcessSymbol(void *symbol, absl::string_view kernel_name)
+    : KernelLoaderSpec(kernel_name), symbol_(symbol) {}
+
+
 OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(absl::string_view filename,
                                                absl::string_view kernel_name)
     : KernelLoaderSpec(kernel_name), filename_(std::string(filename)) {}
@@ -166,6 +170,14 @@ const char *CudaPtxInMemory::original_text(int compute_capability_major,
   return ptx_iter->second;
 }
 
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddInProcessSymbol(
+    void *symbol, absl::string_view kernel_name) {
+  CHECK(in_process_symbol_ == nullptr);
+  in_process_symbol_ =
+      std::make_shared<InProcessSymbol>(symbol, std::string(kernel_name));
+  return this;
+}
+
 MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxOnDisk(
     absl::string_view filename, absl::string_view kernel_name) {
   CHECK(cuda_ptx_on_disk_ == nullptr);
diff --git a/third_party/xla/xla/stream_executor/kernel_spec.h b/third_party/xla/xla/stream_executor/kernel_spec.h
index 68f45a0ee07f4d..b7af7b67b7bdcf 100644
--- a/third_party/xla/xla/stream_executor/kernel_spec.h
+++ b/third_party/xla/xla/stream_executor/kernel_spec.h
@@ -88,6 +88,18 @@ class KernelLoaderSpec {
   void operator=(const KernelLoaderSpec &) = delete;
 };
 
+// Loads kernel from in process symbol pointer (e.g. pointer to C++ device
+// function).
+class InProcessSymbol : public KernelLoaderSpec {
+ public:
+  InProcessSymbol(void *symbol, absl::string_view kernel_name);
+
+  void *symbol() const { return symbol_; }
+
+ private:
+  void *symbol_;
+};
+
 // An abstract kernel loader spec that has an associated file path, where
 // there's a canonical suffix for the filename; e.g. see CudaPtxOnDisk whose
 // canonical filename suffix is ".ptx".
@@ -246,6 +258,7 @@ class MultiKernelLoaderSpec {
 
   // Convenience getters for testing whether these platform variants have
   // kernel loader specifications available.
+  bool has_in_process_symbol() const { return in_process_symbol_ != nullptr; }
   bool has_cuda_ptx_on_disk() const { return cuda_ptx_on_disk_ != nullptr; }
   bool has_cuda_cubin_on_disk() const { return cuda_cubin_on_disk_ != nullptr; }
   bool has_cuda_cubin_in_memory() const {
@@ -255,6 +268,10 @@ class MultiKernelLoaderSpec {
 
   // Accessors for platform variant kernel load specifications.
   // Precondition: corresponding has_* is true.
+  const InProcessSymbol &in_process_symbol() const {
+    CHECK(has_in_process_symbol());
+    return *in_process_symbol_;
+  }
   const CudaPtxOnDisk &cuda_ptx_on_disk() const {
     CHECK(has_cuda_ptx_on_disk());
     return *cuda_ptx_on_disk_;
@@ -278,6 +295,8 @@ class MultiKernelLoaderSpec {
   // Note that the kernel_name parameter must be consistent with the kernel in
   // the PTX being loaded. Also be aware that in CUDA C++ the kernel name may be
   // mangled by the compiler if it is not declared in an extern "C" scope.
+  MultiKernelLoaderSpec *AddInProcessSymbol(void *symbol,
+                                            absl::string_view kernel_name);
   MultiKernelLoaderSpec *AddCudaPtxOnDisk(absl::string_view filename,
                                           absl::string_view kernel_name);
   MultiKernelLoaderSpec *AddCudaCubinOnDisk(absl::string_view filename,
@@ -296,6 +315,8 @@ class MultiKernelLoaderSpec {
                                             absl::string_view kernel_name);
 
  private:
+  std::shared_ptr<InProcessSymbol>
+      in_process_symbol_;  // In process symbol pointer.
   std::unique_ptr<CudaPtxOnDisk>
       cuda_ptx_on_disk_;  // PTX text that resides in a file.
   std::unique_ptr<CudaCubinOnDisk>
diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD
index 4cd6514d267d1c..37a87e7450de28 100644
--- a/third_party/xla/xla/stream_executor/rocm/BUILD
+++ b/third_party/xla/xla/stream_executor/rocm/BUILD
@@ -210,6 +210,7 @@ cc_library(
         "//xla/stream_executor/gpu:gpu_helpers_header",
         "//xla/stream_executor/gpu:gpu_stream_header",
         "//xla/stream_executor/gpu:gpu_timer_header",
+        "//xla/stream_executor/gpu:gpu_blas_lt",
         "//xla/stream_executor/platform",
         "//xla/stream_executor:blas",
         "//xla/stream_executor/platform:dso_loader",
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cu.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cu.cc
new file mode 100644
index 00000000000000..b21e116819b6d5
--- /dev/null
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cu.cc
@@ -0,0 +1,58 @@
+/* Copyright 2023 The OpenXLA Authors.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ==============================================================================*/
+#define LEGACY_HIPBLAS_DIRECT
+#include <algorithm>
+#include "rocm/rocm_config.h"
+#include "rocm/include/hipblaslt/hipblaslt-ext.hpp"
+#include "rocm/include/hipblaslt/hipblaslt.h"
+
+namespace stream_executor {
+namespace rocm {
+
+using namespace hipblaslt_ext;
+
+namespace {
+
+__global__ void CopyUserArgsKernel(UserArguments *dest_args,
+        const void **a, const void **b, const void **c, void **d,
+        uint32_t num_gemms)
+{
+  uint32_t idx = blockIdx.x*blockDim.x + threadIdx.x;
+  if(idx < num_gemms) {
+    // writing ArrayOfStructs is not optimal..
+    auto &arg = dest_args[idx];
+    arg.a = const_cast< void *>(a[idx]);
+    arg.b = const_cast< void *>(b[idx]);
+    arg.c = const_cast< void *>(c[idx]);
+    arg.d = d[idx];
+    //printf("idx: %d %p %p %p %p\n", idx, arg.a, arg.b, arg.c, arg.d);
+  }
+}
+}  // namespace
+
+void GroupGemmUpdateArgs(hipStream_t stream,
+      UserArguments *dev_args,
+      //const gpu::GroupedGemmConfig& cfg
+      const void **a, const void **b, const void **c, void **d,
+      uint32_t num_gemms) {
+
+  const uint32_t block_sz = 128,
+      n_blocks = (num_gemms + block_sz - 1)/block_sz;
+  hipLaunchKernelGGL(CopyUserArgsKernel, n_blocks,
+           std::min(block_sz, num_gemms), 0,
+           stream,
+           dev_args,
+           //static_cast< UserArguments *>(device_args_.opaque()),
+           a, b, c, d, num_gemms);
+}
+}  // namespace rocm
+}  // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc
index 60597e5fe7318e..26167888fbb782 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc
@@ -132,18 +132,13 @@ ROCMBlas::~ROCMBlas() {
 }
 
 bool ROCMBlas::SetStream(Stream *stream) {
-  CHECK(stream != nullptr);
-  CHECK(AsGpuStreamValue(stream) != nullptr);
-  CHECK(blas_ != nullptr);
-  gpu::ScopedActivateExecutorContext sac{parent_};
-  rocblas_status ret =
-      wrap::rocblas_set_stream(blas_, AsGpuStreamValue(stream));
+  rocblas_status ret = wrap::rocblas_set_stream(blas_,
+      stream != nullptr ? AsGpuStreamValue(stream) : 0);
   if (ret != rocblas_status_success) {
     LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
     return false;
   }
-
   return true;
 }
 
@@ -212,13 +207,13 @@ bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
                                   Args... args) {
   absl::MutexLock lock{&mu_};
+  gpu::ScopedActivateExecutorContext sac{parent_};
+
+  CHECK(blas_ != nullptr);
   if (!SetStream(stream)) {
     return false;
   }
 
-  gpu::ScopedActivateExecutorContext sac{parent_};
-
   // set the atomics mode, leaving default to library
   bool allow_atomics = !OpDeterminismRequired();
   rocblas_status ret;
@@ -231,6 +226,7 @@ bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
   }
 
   ret = rocblas_func(blas_, args...);
+  SetStream(nullptr);
   if (err_on_failure && ret != rocblas_status_success) {
     LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName
                << ": " << ToString(ret);
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.h b/third_party/xla/xla/stream_executor/rocm/rocm_blas.h
index a6ec9e0ad6cdd0..3565cb0cf27ef5 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.h
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.h
@@ -94,7 +94,9 @@ class ROCMBlas : public blas::BlasSupport {
   TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
 
 #if TF_HIPBLASLT
-  rocm::BlasLt &blas_lt() { return blas_lt_; }
+  gpu::BlasLt *GetBlasLt() override {
+    return &blas_lt_;
+  }
 #endif
 
  private:
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
index ea6ea59eedd560..73c88e94176ae6 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
@@ -750,23 +750,34 @@ GpuDriver::GraphNodeGetType(hipGraphNode_t node) {
 }
 
 /* static */ tsl::Status GpuDriver::LaunchKernel(
-    GpuContext* context, absl::string_view kernel_name, hipFunction_t function,
-    unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z,
-    unsigned int block_dim_x, unsigned int block_dim_y,
-    unsigned int block_dim_z, unsigned int shared_mem_bytes,
-    GpuStreamHandle stream, void** kernel_params, void** extra) {
+    GpuContext* context, hipFunction_t function, unsigned int grid_dim_x,
+    unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x,
+    unsigned int block_dim_y, unsigned int block_dim_z,
+    unsigned int shared_mem_bytes, GpuStreamHandle stream, void** kernel_params,
+    void** extra) {
   ScopedActivateContext activation{context};
-  VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x
+  VLOG(2) << "launching kernel: " << function << "; gdx: " << grid_dim_x
           << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
           << " bdx: " << block_dim_x << " bdy: " << block_dim_y
           << " bdz: " << block_dim_z << " smem: " << shared_mem_bytes;
-  RETURN_IF_ROCM_ERROR(wrap::hipModuleLaunchKernel(
-                           function, grid_dim_x, grid_dim_y, grid_dim_z,
-                           block_dim_x, block_dim_y, block_dim_z,
-                           shared_mem_bytes, stream, kernel_params, extra),
-                       "Failed to launch ROCm kernel: ", kernel_name,
-                       " with block dimensions: ", block_dim_x, "x",
-                       block_dim_y, "x", block_dim_z);
+
+  // for in-process kernels we use non-null kernel_params:
+  auto res = hipSuccess;
+  if (kernel_params != nullptr) {
+    res = wrap::hipLaunchKernel((const void*)function,
+                                dim3(grid_dim_x, grid_dim_y, grid_dim_z),
+                                dim3(block_dim_x, block_dim_y, block_dim_z),
+                                kernel_params, shared_mem_bytes, stream);
+  } else {
+    res = wrap::hipModuleLaunchKernel(
+        function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
+        block_dim_z, shared_mem_bytes, stream, nullptr, extra);
+  }
+
+  if (res != hipSuccess) {
+    return tsl::errors::Internal(
+        absl::StrCat("Failed to launch ROCM kernel: ", ToString(res)));
+  }
   VLOG(2) << "successfully launched kernel";
   return tsl::OkStatus();
 }
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h
index 5808f4c266ce85..9f14893756a7bf 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h
@@ -114,6 +114,7 @@ namespace wrap {
   __macro(hipHostUnregister)                 \
   __macro(hipInit)                           \
   __macro(hipLaunchHostFunc)                 \
+  __macro(hipLaunchKernel)                   \
   __macro(hipMalloc)                         \
   __macro(hipMemGetAddressRange)             \
   __macro(hipMemGetInfo)                     \
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc
index d5fa92a69f06c5..b91ea684bd398a 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc
@@ -222,21 +222,37 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
       TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, &module));
     }
     kernel_to_gpu_binary_[kernel] = hsaco;
+  } else if (spec.has_in_process_symbol()) {
+    kernel_name = &spec.in_process_symbol().kernel_name();
+    void* symbol = spec.in_process_symbol().symbol();
+
+    VLOG(1) << "Resolve ROCM kernel " << *kernel_name
+            << " from symbol pointer: " << symbol;
+
+    *rocm_kernel->gpu_function_ptr() = static_cast<hipFunction_t>(symbol);
+    rocm_kernel->SetInProcessSymbol(true);
   } else {
     return tsl::errors::Internal("No method of loading ROCM kernel provided");
   }
 
-  VLOG(2) << "getting function " << *kernel_name << " from module " << module;
-  TF_RETURN_IF_ERROR(GpuDriver::GetModuleFunction(
+  // If we resolved kernel from a symbol pointer, there is no need to load it
+  // from a module, as ROCm runtime did that automatically for us.
+  if (!spec.has_in_process_symbol()) {
+    VLOG(2) << "getting function " << *kernel_name << " from module " << module;
+    TF_RETURN_IF_ERROR(GpuDriver::GetModuleFunction(
       context_, module, kernel_name->c_str(), rocm_kernel->gpu_function_ptr()));
+  }
 
   // We have to trust the kernel loader spec arity because there doesn't appear
   // to be a way to reflect on the number of expected arguments w/the ROCM API.
   rocm_kernel->set_arity(spec.arity());
 
-  KernelMetadata kernel_metadata;
-  TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata));
-  kernel->set_metadata(kernel_metadata);
+  // unable to get kernel metadata for in-process kernel
+  if (!spec.has_in_process_symbol()) {
+    KernelMetadata kernel_metadata;
+    TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata));
+    kernel->set_metadata(kernel_metadata);
+  }
   kernel->set_name(*kernel_name);
   return tsl::OkStatus();
 }
@@ -294,17 +310,25 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
     VLOG(2) << "*(arg.address): "
             << reinterpret_cast<void*>(
                    *static_cast<const uint64_t*>(arg.address));
-    kernargs.push_back(
+    // in-process kernel launches need a list of pointers to the params
+    kernargs.push_back(rocm_kernel->IsInProcessSymbol() ?
+        const_cast<void*>(arg.address) :
        reinterpret_cast<void*>(*static_cast<const uint64_t*>(arg.address)));
   }
+  if(rocm_kernel->IsInProcessSymbol()) {
+    return GpuDriver::LaunchKernel(
+        GetGpuContext(stream), hipfunc, block_dims.x, block_dims.y, block_dims.z,
+        thread_dims.x, thread_dims.y, thread_dims.z,
+        args.number_of_shared_bytes(), hipstream, kernargs.data(), nullptr);
+  }
   size_t size = sizeof(void*) * kernargs.size();
   void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, kernargs.data(),
                     HIP_LAUNCH_PARAM_BUFFER_SIZE, &size, HIP_LAUNCH_PARAM_END};
 
   return GpuDriver::LaunchKernel(
-      GetGpuContext(stream), kernel.name(), hipfunc, block_dims.x, block_dims.y,
-      block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
+      GetGpuContext(stream), hipfunc, block_dims.x, block_dims.y, block_dims.z,
+      thread_dims.x, thread_dims.y, thread_dims.z,
       args.number_of_shared_bytes(), hipstream, nullptr, (void**)&config);
 }
diff --git a/third_party/xla/xla/stream_executor/stream.cc b/third_party/xla/xla/stream_executor/stream.cc
index 25916abc2469ce..bb71a139af5730 100644
--- a/third_party/xla/xla/stream_executor/stream.cc
+++ b/third_party/xla/xla/stream_executor/stream.cc
@@ -1193,6 +1193,117 @@ Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k,
                              x, incx, beta, y, incy);
 }
 
+template <typename InputType>
+tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+                                 uint64_t m, uint64 n, uint64 k,
+                                 const DeviceMemory<InputType> &a, int lda,
+                                 const DeviceMemory<InputType> &b, int ldb,
+                                 DeviceMemory<InputType> *c, int ldc,
+                                 blas::ComputePrecision precision,
+                                 blas::CallContext context) {
+  InputType alpha{1.0};
+  InputType beta{0.0};
+  if(gpu::GpuBlasLtEnabled()) {
+    auto& r = gpu::BlasLtGemmRunner::i(this);
+    CheckStatus(r.Run(*this, transa, transb, m, n, k,
+        alpha, a, lda, b, ldb, beta, c, ldc,
+        /* allocator */nullptr));  //! NOTE: allocator is not available!!
+    return ::tsl::OkStatus();
+  }
+  return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
+                      ldc, precision, context);
+}
+
+#define INSTANTIATE_THEN_BLAS_GEMM(INPUT_TYPE)               \
+  template tsl::Status Stream::ThenBlasGemm<INPUT_TYPE>(     \
+      blas::Transpose transa, blas::Transpose transb,        \
+      uint64_t m, uint64 n, uint64 k,                        \
+      const DeviceMemory<INPUT_TYPE>& a, int lda,            \
+      const DeviceMemory<INPUT_TYPE>& b, int ldb,            \
+      DeviceMemory<INPUT_TYPE>* c, int ldc,                  \
+      blas::ComputePrecision precision,                      \
+      blas::CallContext context);
+
+INSTANTIATE_THEN_BLAS_GEMM(float)
+INSTANTIATE_THEN_BLAS_GEMM(double)
+INSTANTIATE_THEN_BLAS_GEMM(Eigen::half)
+INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16)
+INSTANTIATE_THEN_BLAS_GEMM(std::complex<float>)
+INSTANTIATE_THEN_BLAS_GEMM(std::complex<double>)
+
+#undef INSTANTIATE_THEN_BLAS_GEMM
+
+template <typename InputType, typename ConstantType>
+tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+                                 uint64_t m, uint64 n, uint64 k, ConstantType alpha,
+                                 const DeviceMemory<InputType> &a, int lda,
+                                 const DeviceMemory<InputType> &b, int ldb,
+                                 ConstantType beta, DeviceMemory<InputType> *c,
+                                 int ldc, blas::ComputePrecision precision,
+                                 blas::CallContext context) {
+  if(gpu::GpuBlasLtEnabled()) {
+    auto& r = gpu::BlasLtGemmRunner::i(this);
+    CheckStatus(r.Run(*this, transa, transb, m, n, k,
+        alpha, a, lda, b, ldb, beta, c, ldc,
+        /* allocator */nullptr));  //! NOTE: allocator is not available!!
+    return ::tsl::OkStatus();
+  }
+  static_assert(
+      detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16, float, double,
+                        std::complex<float>, std::complex<double>>(),
+      "Input can be half, bf16, float, double, std::complex<float> or "
+      "std::complex<double>");
+  static_assert(!std::is_same_v<InputType, Eigen::half> ||
+                    detail::is_any_of<ConstantType, float, Eigen::half>(),
+                "If input is Eigen::half, constant has to be either "
+                "Eigen::half or float");
+  static_assert(!std::is_same_v<InputType, Eigen::bfloat16> ||
+                    detail::is_any_of<ConstantType, float, Eigen::bfloat16>(),
+                "If input is Eigen::bfloat16, constant has to be either "
+                "Eigen::bfloat16 or float");
+  static_assert(
+      detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16, ConstantType>(),
+      "If input is not Eigen::half, constant and input types have to match");
+  blas::BlasSupport *blas = parent()->AsBlas();
+  if (!blas) {
+    return tsl::errors::Internal(
+        "Attempting to perform BLAS operation using "
+        "StreamExecutor without BLAS support");
+  }
+
+  void *alpha_ptr = &alpha;
+  void *beta_ptr = &beta;
+  float alpha_storage, beta_storage;
+  UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
+                                  &beta_storage);
+
+  return blas->DoBlasGemm(this, transa, transb, m, n, k,
+                          blas::ToDataType<InputType>::value, alpha_ptr, a,
+                          lda, b, ldb, beta_ptr, c, ldc, precision,
+                          context);
+}
+
+#define INSTANTIATE_THEN_BLAS_GEMM(INPUT_TYPE, CONSTANT_TYPE)            \
+  template tsl::Status Stream::ThenBlasGemm<INPUT_TYPE, CONSTANT_TYPE>(  \
+      blas::Transpose transa, blas::Transpose transb,                    \
+      uint64_t m, uint64 n, uint64 k, CONSTANT_TYPE alpha,               \
+      const DeviceMemory<INPUT_TYPE>& a, int lda,                        \
+      const DeviceMemory<INPUT_TYPE>& b, int ldb,                        \
+      CONSTANT_TYPE beta, DeviceMemory<INPUT_TYPE>* c, int ldc,          \
+      blas::ComputePrecision precision,                                  \
+      blas::CallContext context);
+
+INSTANTIATE_THEN_BLAS_GEMM(float, float)
+INSTANTIATE_THEN_BLAS_GEMM(double, double)
+INSTANTIATE_THEN_BLAS_GEMM(Eigen::half, Eigen::half)
+INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16, Eigen::bfloat16)
+INSTANTIATE_THEN_BLAS_GEMM(std::complex<float>, std::complex<float>)
+INSTANTIATE_THEN_BLAS_GEMM(std::complex<double>, std::complex<double>)
+INSTANTIATE_THEN_BLAS_GEMM(Eigen::half, float)
+INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16, float)
+
+#undef INSTANTIATE_THEN_BLAS_GEMM
+
 namespace {
 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
 // blas::ProfileResult*. This functor doesn't put the stream into an error
diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h
index f8903f1c946ff5..2b36c2a675bf3e 100644
--- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h
+++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h
@@ -388,6 +388,11 @@ class StreamExecutor {
       absl::string_view kernel_name, absl::string_view ptx,
       absl::Span<const uint8_t> cubin_data);
 
+  // Same as above but for in-process kernels
+  template <typename... Args>
+  tsl::StatusOr<std::unique_ptr<TypedKernel<Args...>>> CreateTypedKernel(
+      absl::string_view kernel_name, void *symbol);
+
   // Warning: use Stream::ThenLaunch instead, this method is not for general
   // consumption. However, this is the only way to launch a kernel for which
   // the type signature is only known at runtime; say, if an application
@@ -701,6 +706,18 @@ StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
   return std::move(kernel_base);
 }
 
+template <typename... Args>
+inline tsl::StatusOr<std::unique_ptr<TypedKernel<Args...>>>
+StreamExecutor::CreateTypedKernel(absl::string_view kernel_name,
+                                  void *symbol) {
+  auto kernel_base = absl::make_unique<TypedKernel<Args...>>(this);
+  MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
+  loader_spec.AddInProcessSymbol(symbol, kernel_name);
+
+  TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get()));
+  return std::move(kernel_base);
+}
+
 template <typename T>
 inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64_t element_count,
                                                      int64_t memory_space) {
diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc
index dce43582121e28..e490347398917b 100644
--- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc
+++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc
@@ -547,10 +547,6 @@ tsl::StatusOr LhloDialectEmitter::EmitCustomCallOp(
     return EmitCublasLtMatmul(custom_call_instr);
   }
 
-  if (xla::gpu::IsCublasLtMatmulF8(*instr)) {
-    return EmitCublasLtMatmulF8(custom_call_instr);
-  }
-
   if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) {
     return EmitDnnConvolution(custom_call_instr);
   }
@@ -802,19 +798,19 @@ tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmul(
 
   TF_ASSIGN_OR_RETURN(
       bool has_vector_bias,
-      xla::gpu::cublas_lt::EpilogueAddsVectorBias(config.epilogue()));
+      xla::gpu::gpublas_lt::EpilogueAddsVectorBias(config.epilogue()));
 
   TF_ASSIGN_OR_RETURN(
       bool has_aux_output,
-      xla::gpu::cublas_lt::EpilogueHasAuxiliaryOutput(config.epilogue()));
+      xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(config.epilogue()));
 
   TF_RET_CHECK(custom_call->operand_count() ==
                2 + int{has_matrix_bias} + int{has_vector_bias});
 
   xla::ShapeIndex output_index =
-      has_aux_output ? xla::ShapeIndex{0} : xla::ShapeIndex{};
+      has_aux_output ? xla::ShapeIndex{0, 0} : xla::ShapeIndex{0};
 
-  llvm::SmallVector operands;
+  llvm::SmallVector operands;
   TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands));
   TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands));
   if (has_matrix_bias) {
@@ -830,15 +826,18 @@ tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmul(
   }
 
   if (has_aux_output) {
-    TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, {1}));
+    TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands,
+                                       xla::ShapeIndex{0, 1}));
   }
+  TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands,
+                                     xla::ShapeIndex{1}));
 
   auto op = CreateOpWithoutAttrs<lmhlo_gpu::CublasLtMatmulOp>(custom_call, operands);
   SetMatmulAttributes(op, config, builder_);
 
   int32_t operand_sizes[] = {
-      1, 1, 1, 1, has_vector_bias ? 1 : 0, has_aux_output ? 1 : 0};
+      1, 1, 1, 1, has_vector_bias ? 1 : 0, has_aux_output ? 1 : 0, 1};
   op->setAttr(op.getOperandSegmentSizeAttr(),
               builder_.getDenseI32ArrayAttr(operand_sizes));
 
@@ -866,7 +865,7 @@ tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmulF8(
   TF_RET_CHECK(ops_num == 6 || ops_num == 7 || ops_num == 8);
   TF_ASSIGN_OR_RETURN(
       bool has_vector_bias,
-      xla::gpu::cublas_lt::EpilogueAddsVectorBias(config.epilogue()));
+      xla::gpu::gpublas_lt::EpilogueAddsVectorBias(config.epilogue()));
 
   bool has_damax = custom_call->shape().IsTuple();
   bool has_matrix_bias = config.beta() != 0.;
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index 7e01954d0c2475..17a246d76ae2c7 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -643,7 +643,33 @@ message DebugOptions {
   // Maximum number of buffers to print when debugging buffer assignment.
   int64 xla_debug_buffer_assignment_show_max = 251;
 
-  // Next id: 256
+  // Threshold to rewrite matmul to cuBLAS or Triton (minimum combined number of
+  // elements of both matrices in non-batch dimensions to be considered for a
+  // rewrite).
+  int64 xla_gpu_gemm_rewrite_size_threshold = 256;
+
+  // File to write autotune results to. It will be a binary file unless the name
+  // ends with .txt or .textproto. Warning: The results are written at every
+  // compilation, possibly multiple times per process. This only works on CUDA.
+  string xla_gpu_dump_autotune_results_to = 257;
+
+  // File to load autotune results from. It will be considered a binary file
+  // unless the name ends with .txt or .textproto. At most one loading will
+  // happen during the lifetime of one process, even if the first one is
+  // unsuccessful or different file paths are passed here. This only works on
+  // CUDA.
+  string xla_gpu_load_autotune_results_from = 258;
+
+  // Relative precision for comparing different GEMM solutions.
+  float xla_gpu_autotune_gemm_rtol = 259;
+
+  // Higher values make it more likely that we'll catch an out-of-bounds read or
+  // write. Smaller values consume less memory during autotuning. Note that a
+  // fused cudnn conv has up to 6 total buffers (4 inputs, 1 output, and 1
+  // scratch), so this can be multiplied by quite a lot.
+  int64 xla_gpu_redzone_padding_bytes = 260;
+
+  // Next id: 261
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
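For reference, a minimal sketch of how the DebugOptions fields added in the hunk above can be set from C++. It assumes nothing beyond the protoc-generated setters that follow from the field names in this patch (set_xla_gpu_gemm_rewrite_size_threshold and friends); the command-line flag plumbing lives in debug_options_flags.cc, which this patch also modifies but is not shown here. The values are illustrative placeholders, not recommended defaults:

  #include "xla/xla.pb.h"

  xla::DebugOptions MakeGemmAutotuneOptions() {
    xla::DebugOptions opts;
    // Skip the cuBLAS/Triton GEMM rewrite for very small matmuls.
    opts.set_xla_gpu_gemm_rewrite_size_threshold(1024);
    // Relative tolerance used when comparing candidate GEMM algorithms.
    opts.set_xla_gpu_autotune_gemm_rtol(0.1f);
    // Wider redzones catch more out-of-bounds accesses during autotuning.
    opts.set_xla_gpu_redzone_padding_bytes(8 * 1024 * 1024);
    // Persist and reuse autotune results across compilations (CUDA only).
    opts.set_xla_gpu_dump_autotune_results_to("/tmp/autotune.textproto");
    opts.set_xla_gpu_load_autotune_results_from("/tmp/autotune.textproto");
    return opts;
  }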