From 60940c53d49550611462a414e4f98b9f88df4a07 Mon Sep 17 00:00:00 2001 From: Pavel Emeliyanenko Date: Tue, 15 Oct 2024 11:09:09 -0500 Subject: [PATCH] supported for hipblaslt with autotuning amdhipblaslt_plugin compile gpu_executable compiled gpublas_lt thunk builds unit test compiles more adaptions for workspace buffer and mhlo starting autotuner backport updated picker adding autotuner support autotuner compiles autotuner update remaining autotuner updates remaining build fixes added missing tf32 support gpu_blas_lt_gemm_runner disable check fix location of header files forward gemm calls to gpu blas lt runner minor fix explicit instantiation of ThenBlasGemm use default as fallback algorithm --- build_rocm_python3 | 2 +- .../compiler/xla/debug_options_flags.cc | 71 +- .../xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td | 1 + .../xla/service/computation_placer.cc | 2 +- tensorflow/compiler/xla/service/gpu/BUILD | 143 +- .../xla/service/gpu/autotuner_compile_util.cc | 128 ++ .../xla/service/gpu/autotuner_compile_util.h | 98 ++ .../xla/service/gpu/autotuner_util.cc | 247 +++ .../compiler/xla/service/gpu/autotuner_util.h | 177 ++ .../xla/service/gpu/buffer_comparator.cc | 783 ++------- .../xla/service/gpu/buffer_comparator.cu.cc | 184 ++ .../xla/service/gpu/buffer_comparator.h | 36 +- .../compiler/xla/service/gpu/cublas_cudnn.cc | 6 - .../compiler/xla/service/gpu/cublas_cudnn.h | 6 - .../xla/service/gpu/cublas_lt_matmul_thunk.cc | 100 -- .../xla/service/gpu/gemm_algorithm_picker.cc | 779 ++++----- .../xla/service/gpu/gemm_algorithm_picker.h | 92 +- .../compiler/xla/service/gpu/gemm_rewriter.cc | 1481 ++++++++--------- .../compiler/xla/service/gpu/gemm_rewriter.h | 10 +- .../compiler/xla/service/gpu/gemm_thunk.cc | 7 +- .../compiler/xla/service/gpu/gpu_compiler.cc | 41 +- .../service/gpu/gpu_conv_algorithm_picker.cc | 42 +- .../service/gpu/gpu_conv_algorithm_picker.h | 6 +- .../service/gpu/gpu_serializable_autotuner.h | 76 - .../service/gpu/gpublas_lt_matmul_thunk.cc | 185 ++ ...tmul_thunk.h => gpublas_lt_matmul_thunk.h} | 48 +- .../xla/service/gpu/ir_emission_utils.cc | 47 +- .../xla/service/gpu/ir_emission_utils.h | 1 + .../xla/service/gpu/ir_emitter_unnested.cc | 21 +- .../xla/service/gpu/ir_emitter_unnested.h | 2 +- .../compiler/xla/service/gpu/matmul_utils.cc | 1027 +++++------- .../compiler/xla/service/gpu/matmul_utils.h | 279 ++-- .../compiler/xla/service/gpu/runtime/BUILD | 7 +- .../compiler/xla/service/gpu/runtime/conv.cc | 9 +- .../service/gpu/runtime/cublas_lt_matmul.cc | 2 +- .../compiler/xla/service/gpu/runtime/gemm.cc | 9 +- .../xla/service/gpu/runtime/support.h | 3 +- .../xla/service/gpu/stream_executor_util.cc | 248 ++- .../xla/service/gpu/stream_executor_util.h | 3 +- .../compiler/xla/service/gpu/tests/BUILD | 15 + .../service/gpu/tests/gpu_hlo_runner_test.cc | 130 ++ tensorflow/compiler/xla/stream_executor/BUILD | 1 + .../compiler/xla/stream_executor/blas.cc | 21 + .../compiler/xla/stream_executor/blas.h | 35 + .../compiler/xla/stream_executor/gpu/BUILD | 106 +- .../xla/stream_executor/gpu/gpu_blas_lt.cc | 285 ++++ .../xla/stream_executor/gpu/gpu_blas_lt.h | 278 ++++ .../gpu/gpu_blas_lt_gemm_runner.cc | 341 ++++ .../gpu/gpu_blas_lt_gemm_runner.h | 260 +++ .../xla/stream_executor/gpu/gpu_driver.h | 2 +- .../xla/stream_executor/gpu/gpu_kernel.h | 6 +- .../stream_executor/gpu/redzone_allocator.cc | 193 +-- .../stream_executor/gpu/redzone_allocator.h | 33 +- .../gpu/redzone_allocator_kernel.h | 39 + .../gpu/redzone_allocator_kernel_cuda.cc | 147 ++ .../gpu/redzone_allocator_kernel_rocm.cu.cc | 50 + .../gpu/redzone_allocator_test.cc | 155 ++ .../xla/stream_executor/kernel_spec.cc | 11 + .../xla/stream_executor/kernel_spec.h | 21 + .../compiler/xla/stream_executor/rocm/BUILD | 78 +- .../xla/stream_executor/rocm/hip_blas_lt.cc | 736 ++++++++ .../stream_executor/rocm/hip_blas_lt.cu.cc | 58 + .../xla/stream_executor/rocm/hip_blas_lt.h | 202 +++ .../stream_executor/rocm/hip_blas_utils.cc | 78 + .../xla/stream_executor/rocm/hip_blas_utils.h | 53 + .../stream_executor/rocm/hipblaslt_wrapper.h | 102 ++ .../xla/stream_executor/rocm/rocm_blas.cc | 38 +- .../xla/stream_executor/rocm/rocm_blas.h | 7 + .../xla/stream_executor/rocm/rocm_driver.cc | 37 +- .../rocm/rocm_driver_wrapper.h | 1 + .../stream_executor/rocm/rocm_gpu_executor.cc | 42 +- .../compiler/xla/stream_executor/stream.cc | 143 ++ .../compiler/xla/stream_executor/stream.h | 70 +- .../stream_executor/stream_executor_pimpl.h | 17 + tensorflow/compiler/xla/tests/matmul_test.cc | 4 - .../mhlo_to_lhlo_with_xla.cc | 21 +- tensorflow/compiler/xla/xla.proto | 28 +- .../bin/crosstool_wrapper_driver_rocm.tpl | 31 +- third_party/gpus/rocm/build_defs.bzl.tpl | 8 + third_party/gpus/rocm/rocm_config.h.tpl | 1 + third_party/gpus/rocm_configure.bzl | 16 + 81 files changed, 6773 insertions(+), 3536 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc create mode 100644 tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h create mode 100644 tensorflow/compiler/xla/service/gpu/autotuner_util.cc create mode 100644 tensorflow/compiler/xla/service/gpu/autotuner_util.h create mode 100644 tensorflow/compiler/xla/service/gpu/buffer_comparator.cu.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/gpu_serializable_autotuner.h create mode 100644 tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.cc rename tensorflow/compiler/xla/service/gpu/{cublas_lt_matmul_thunk.h => gpublas_lt_matmul_thunk.h} (52%) create mode 100644 tensorflow/compiler/xla/service/gpu/tests/gpu_hlo_runner_test.cc create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel.h create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_test.cc create mode 100644 tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cc create mode 100644 tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cu.cc create mode 100644 tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h create mode 100644 tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.cc create mode 100644 tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h create mode 100644 tensorflow/compiler/xla/stream_executor/rocm/hipblaslt_wrapper.h diff --git a/build_rocm_python3 b/build_rocm_python3 index b21888e3d3c480..27eedd2b916313 100755 --- a/build_rocm_python3 +++ b/build_rocm_python3 @@ -26,7 +26,7 @@ done shift "$((OPTIND-1))" # First positional argument (if any) specifies the ROCM_INSTALL_DIR -ROCM_INSTALL_DIR=/opt/rocm-6.2.0 +ROCM_INSTALL_DIR=$(realpath /opt/rocm) if [[ -n $1 ]]; then ROCM_INSTALL_DIR=$1 fi diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 7ee8d906d2e754..632ec2541933f7 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -38,7 +38,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_llvm_enable_invariant_load_metadata(true); opts.set_xla_llvm_disable_expensive_passes(false); opts.set_xla_backend_optimization_level(3); - opts.set_xla_gpu_autotune_level(4); + opts.set_xla_gpu_autotune_level(0); opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); opts.set_xla_gpu_asm_extra_flags(""); @@ -74,7 +74,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // Note: CublasLt will be used for FP8 GEMMs regardless of the value of this // flag. - opts.set_xla_gpu_enable_cublaslt(false); + opts.set_xla_gpu_enable_cublaslt(true); // TODO(b/258036887): Enable once CUDA Graphs are fully supported. opts.set_xla_gpu_cuda_graph_level(0); @@ -122,7 +122,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_partitioning_algorithm( DebugOptions::PARTITIONING_ALGORITHM_NOOP); - opts.set_xla_gpu_enable_triton_gemm(true); + opts.set_xla_gpu_enable_triton_gemm(false); opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); @@ -131,6 +131,19 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(false); opts.set_xla_gpu_collective_inflation_factor(1); + + opts.set_xla_gpu_autotune_gemm_rtol(0.1f); + + opts.set_xla_gpu_redzone_padding_bytes(8 * 1024 * 1024); + + // Minimum combined size of matrices in matrix multiplication to + // be rewritten to cuBLAS or Triton kernel call. + // This threshold is a conservative estimate and has been measured + // to be always beneficial (up to generally several times faster) + // on V100 and H100 GPUs. See openxla/xla #9319 for details. + const int64_t kDefaultMinGemmRewriteSize = 100; + opts.set_xla_gpu_gemm_rewrite_size_threshold(kDefaultMinGemmRewriteSize); + return opts; } @@ -209,6 +222,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, }; }; + auto float_setter_for = + [debug_options](void (DebugOptions::*member_setter)(float)) { + return [debug_options, member_setter](float value) { + (debug_options->*member_setter)(value); + return true; + }; + }; + // Custom "sub-parser" lambda for xla_gpu_shape_checks. auto setter_for_xla_gpu_shape_checks = [debug_options](const std::string& value) { @@ -527,7 +548,35 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level), debug_options->xla_gpu_autotune_level(), "Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = " - "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check.")); + "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check; " + "5 = on+init+reinit+check and skip WRONG_RESULT solutions. See also " + "the related flag xla_gpu_autotune_gemm_rtol. Remark that, setting the " + "level to 5 only makes sense if you are sure that the reference (first " + "in the list) solution is numerically CORRECT. Otherwise, the autotuner " + "might discard many other correct solutions based on the failed " + "BufferComparator test.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_dump_autotune_results_to", + string_setter_for(&DebugOptions::set_xla_gpu_dump_autotune_results_to), + debug_options->xla_gpu_dump_autotune_results_to(), + "File to write autotune results to. It will be a binary file unless the " + "name ends with .txt or .textproto. Warning: The results are written at " + "every compilation, possibly multiple times per process. This only works " + "on CUDA. In tests, the TEST_UNDECLARED_OUTPUTS_DIR prefix can be used " + "to write to their output directory.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_load_autotune_results_from", + string_setter_for(&DebugOptions::set_xla_gpu_load_autotune_results_from), + debug_options->xla_gpu_load_autotune_results_from(), + "File to load autotune results from. It will be considered a binary file " + "unless the name ends with .txt or .textproto. It will be loaded at most " + "once per process. This only works on CUDA. In tests, the TEST_WORKSPACE " + "prefix can be used to load files from their data dependencies.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_autotune_gemm_rtol", + float_setter_for(&DebugOptions::set_xla_gpu_autotune_gemm_rtol), + debug_options->xla_gpu_autotune_gemm_rtol(), + "Relative precision for comparing GEMM solutions vs the reference one")); flag_list->push_back(tsl::Flag( "xla_force_host_platform_device_count", int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count), @@ -823,6 +872,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_cublaslt), debug_options->xla_gpu_enable_cublaslt(), "Use cuBLASLt for GEMMs when possible.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_gemm_rewrite_size_threshold", + int64_setter_for(&DebugOptions::set_xla_gpu_gemm_rewrite_size_threshold), + debug_options->xla_gpu_gemm_rewrite_size_threshold(), + "Threshold until which elemental dot emitter is preferred for GEMMs " + "(minumum combined number of elements of both matrices " + "in non-batch dimensions to be considered for a rewrite).")); flag_list->push_back(tsl::Flag( "xla_gpu_cuda_graph_level", int32_setter_for(&DebugOptions::set_xla_gpu_cuda_graph_level), @@ -994,6 +1050,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_triton_gemm_any(), "Use Triton-based matrix multiplication for any GEMM it " "supports without filtering only faster ones.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_redzone_padding_bytes", + int64_setter_for(&DebugOptions::set_xla_gpu_redzone_padding_bytes), + debug_options->xla_gpu_redzone_padding_bytes(), + "Amount of padding the redzone allocator will put on one side of each " + "buffer it allocates. (So the buffer's total size will be increased by " + "2x this value.)")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index 3e927de23ab18f..9518ba177a4fbf 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -168,6 +168,7 @@ def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandS Arg:$d, Arg, "", [MemRead]>:$bias, Arg, "", [MemRead, MemWrite]>:$aux, + Arg, "", [MemRead, MemWrite]>:$workspace, MHLO_DotDimensionNumbers:$dot_dimension_numbers, MHLO_PrecisionConfigAttr:$precision_config, F64Attr:$alpha_real, diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index f00a1399aefec3..29972f2764af8d 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -163,7 +163,7 @@ StatusOr ComputationPlacer::AssignDevices( ComputationPlacerCreationFunction creation_function) { absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); - CHECK(computation_placers->find(platform_id) == computation_placers->end()); + // CHECK(computation_placers->find(platform_id) == computation_placers->end()); (*computation_placers)[platform_id].creation_function = creation_function; } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b3b63290749003..a1bcafb9c7364b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -6,7 +6,7 @@ load( "//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library", ) -load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gpu_kernel_library") load( "//tensorflow/tsl/platform:build_config_root.bzl", "if_static", @@ -379,7 +379,6 @@ cc_library( ":triangular_solve_thunk", ":cholesky_thunk", ]) + if_cuda_is_configured([ - ":cublas_lt_matmul_thunk", ":ir_emitter_triton", ]), ) @@ -516,7 +515,6 @@ cc_library( ":gpu_executable", ":gpu_float_support", ":gpu_fusible", - ":gpu_serializable_autotuner", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter_triton", @@ -823,6 +821,7 @@ cc_library( "custom_call_thunk.h", "for_thunk.h", "gemm_thunk.h", + "gpublas_lt_matmul_thunk.h", "gpu_executable.h", "infeed_thunk.h", "kernel_thunk.h", @@ -839,6 +838,7 @@ cc_library( ":custom_call_thunk", ":fft_thunk", ":gemm_thunk", + ":gpublas_lt_matmul_thunk", ":gpu_asm_opts_util", ":gpu_constants", ":gpu_conv_runner", @@ -1159,68 +1159,82 @@ cc_library( ) cc_library( - name = "cublas_lt_matmul_thunk", - srcs = if_cuda_is_configured(["cublas_lt_matmul_thunk.cc"]), - hdrs = if_cuda_is_configured(["cublas_lt_matmul_thunk.h"]), - deps = if_cuda_is_configured([ + name = "gpublas_lt_matmul_thunk", + srcs = ["gpublas_lt_matmul_thunk.cc"], + hdrs = ["gpublas_lt_matmul_thunk.h"], + deps = [ ":matmul_utils", ":thunk", - "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla:status", - "//tensorflow/tsl/platform/default/build_config:cublas_plugin", + "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", - "//tensorflow/compiler/xla/stream_executor/cuda:cublas_lt_header", - ]) + ["//tensorflow/tsl/platform:logging"], + "//tensorflow/tsl/platform:logging", + ], +) + +cc_library( + name = "autotuner_compile_util", + srcs = if_gpu_is_configured(["autotuner_compile_util.cc"]), + hdrs = if_gpu_is_configured(["autotuner_compile_util.h"]), + deps = if_gpu_is_configured([ + ":autotuner_util", + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_stream_header", + "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + ]) ) + cc_library( name = "gemm_algorithm_picker", - srcs = if_cuda_is_configured(["gemm_algorithm_picker.cc"]), - hdrs = if_cuda_is_configured(["gemm_algorithm_picker.h"]), - deps = if_cuda_is_configured([ + srcs = ["gemm_algorithm_picker.cc"], + hdrs = ["gemm_algorithm_picker.h"], + deps = [ ":backend_configs_cc", ":buffer_comparator", - ":gemm_thunk", - ":gpu_asm_opts_util", ":gpu_conv_runner", + ":gpu_executable", ":ir_emission_utils", - ":matmul_utils", ":stream_executor_util", - ":gpu_serializable_autotuner", - "@com_google_absl//absl/strings", - "//tensorflow/compiler/xla:autotune_results_proto_cc", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/service:hlo_pass", + ":matmul_utils", + ":autotuner_compile_util", + ":autotuner_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/stream_executor", - "//tensorflow/compiler/xla/stream_executor:blas", - "//tensorflow/compiler/xla/stream_executor/cuda:cublas_lt_header", - "//tensorflow/compiler/xla/stream_executor:device_memory", - "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", - "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", "//tensorflow/compiler/xla:util", - "//tensorflow/tsl/platform:errors", - "//tensorflow/tsl/platform/default/build_config:cublas_plugin", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/util/proto:proto_utils", "//tensorflow/tsl/platform:logger", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:statusor", "//tensorflow/tsl/protobuf:autotuning_proto_cc", - "//tensorflow/tsl/util/proto:proto_utils", - ]), + "//tensorflow/compiler/xla/stream_executor:blas", + "//tensorflow/compiler/xla/stream_executor:device_memory", + "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", + "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", + "@com_google_absl//absl/types:optional", + ], ) cc_library( - name = "gpu_serializable_autotuner", - srcs = [], - hdrs = ["gpu_serializable_autotuner.h"], - deps = [ - "//tensorflow/compiler/xla:autotune_results_proto_cc", + name = "autotuner_util", + srcs = if_gpu_is_configured(["autotuner_util.cc"]), + hdrs = if_gpu_is_configured(["autotuner_util.h"]), + deps = if_gpu_is_configured([ + ":stream_executor_util", + ":gpu_asm_opts_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", "//tensorflow/tsl/protobuf:autotuning_proto_cc", - ], + ]), ) tf_cc_test( @@ -1268,13 +1282,13 @@ cc_library( "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:any", ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/stream_executor/cuda:cublas_lt_header", - "//tensorflow/tsl/platform/default/build_config:cublas_plugin", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "//tensorflow/compiler/xla/stream_executor:host_or_device_scalar", "//tensorflow/compiler/xla/stream_executor:scratch_allocator", @@ -1393,9 +1407,9 @@ cc_library( ":gpu_asm_opts_util", ":gpu_autotuning_proto_cc", ":gpu_conv_runner", - ":gpu_serializable_autotuner", ":hlo_algorithm_denylist", ":stream_executor_util", + ":autotuner_util", "//tensorflow/compiler/xla:autotune_results_proto_cc", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", @@ -1411,13 +1425,13 @@ cc_library( "//tensorflow/tsl/platform:numbers", "//tensorflow/tsl/protobuf:autotuning_proto_cc", "//tensorflow/tsl/util/proto:proto_utils", + "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ] + if_cuda_is_configured([ ":buffer_comparator", "@local_config_cuda//cuda:cudnn_header", - "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", ]), ) @@ -2115,7 +2129,6 @@ cc_library( ":gpu_reduce_scatter_creator", ":gpu_sanitize_constant_names", ":gpu_scatter_expander", - ":gpu_serializable_autotuner", ":gpu_shape_verifier", ":hlo_fusion_stats", ":horizontal_input_fusion", @@ -2136,6 +2149,7 @@ cc_library( ":topk_specializer", ":tree_reduction_rewriter", ":variadic_op_splitter", + ":gemm_algorithm_picker", "//tensorflow/compiler/xla:autotune_results_proto_cc", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -2258,7 +2272,6 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", ] + if_cuda_is_configured([ - ":gemm_algorithm_picker", ":triton_autotuner", ]), ) @@ -2315,7 +2328,6 @@ cc_library( ":metrics", ":target_constants", ":triangular_solve_rewriter", - ":gpu_serializable_autotuner", "@com_google_absl//absl/base", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings:str_format", @@ -2663,6 +2675,8 @@ cc_library( "//tensorflow/tsl/protobuf:autotuning_proto_cc", "//tensorflow/tsl/util:env_var", "//tensorflow/tsl/util/proto:proto_utils", + #"//tensorflow/core/framework:types_proto_cc", + #"//tensorflow/core/framework:full_type_proto_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -2730,22 +2744,37 @@ tf_cc_test( ], ) +tf_gpu_kernel_library( + name = "buffer_comparator_kernel", + hdrs = if_gpu_is_configured(["buffer_comparator.h"]), + srcs = if_gpu_is_configured(["buffer_comparator.cu.cc"]), + deps = [], +) + cc_library( name = "buffer_comparator", - srcs = if_cuda_is_configured(["buffer_comparator.cc"]), - hdrs = if_cuda_is_configured(["buffer_comparator.h"]), - deps = if_cuda_is_configured([ - ":launch_dimensions", - ":gpu_asm_opts_util", - "@com_google_absl//absl/base", - "@com_google_absl//absl/strings", - "//tensorflow/compiler/xla:shape_util", + hdrs = if_gpu_is_configured([ + "buffer_comparator.h" + ]), + srcs = if_gpu_is_configured([ + "buffer_comparator.cc", + ]), + copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "-DTENSORFLOW_USE_ROCM=1", + ]), + deps = [ + ":buffer_comparator_kernel", + ":stream_executor_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/stream_executor", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", - "//tensorflow/compiler/xla/stream_executor/gpu:asm_compiler", + "@com_google_absl//absl/strings", + ] + if_cuda_is_configured([ + ":partition_assignment", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/stream_executor/cuda:ptxas_utils", ]), ) diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc new file mode 100644 index 00000000000000..8565f1c41f6fd9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -0,0 +1,128 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace xla { +namespace gpu { + +/* static */ StatusOr RedzoneBuffers::FromInstruction( + const HloInstruction& instruction, const AutotuneConfig& config, + se::Stream *stream, const DebugOptions& debug_options, + BuffersToCreate buffers_to_create) { + + std::vector< Shape > input_shapes; + input_shapes.reserve(instruction.operand_count()); + for (const auto* operand : instruction.operands()) { + input_shapes.push_back(operand->shape()); + } + return FromShapes(std::move(input_shapes), instruction.shape(), + config, stream, debug_options, buffers_to_create); +} + +/* static */ StatusOr RedzoneBuffers::FromShapes( + std::vector&& input_shapes, const Shape& output_shape, + const AutotuneConfig& config, se::Stream *stream, + const DebugOptions& debug_options, BuffersToCreate buffers_to_create) { + RedzoneBuffers buffers; + + TF_ASSIGN_OR_RETURN(auto rz_allocator, AutotunerUtil::CreateRedzoneAllocator( + config, stream, debug_options)); + buffers.redzone_allocator_ = + std::make_unique(std::move(rz_allocator)); + + int64_t rng_state = 0; + TF_RETURN_IF_ERROR( + buffers.CreateInputs(std::move(input_shapes), config, rng_state)); + + if (buffers_to_create == BuffersToCreate::kAllInputsAllOutputs || + buffers_to_create == BuffersToCreate::kAllInputsOutputsNoScratch) { + TF_RETURN_IF_ERROR(buffers.CreateOutputs(output_shape, config, + buffers_to_create, rng_state)); + } + return buffers; +} + +Status RedzoneBuffers::CreateInputs(std::vector&& input_shapes, + const AutotuneConfig& config, int64_t& rng_state) { + input_shapes_ = std::move(input_shapes); + for (const auto& shape : input_shapes_) { + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase buf, + AutotunerUtil::CreateBuffer(*redzone_allocator_, shape, + config, rng_state)); + input_buffers_.push_back(buf); + } + return OkStatus(); +} + +Status RedzoneBuffers::CreateOutputs(const Shape& output_shape, + const AutotuneConfig& config, + BuffersToCreate buffers_to_create, + int64_t& rng_state) { + if (!output_shape.IsTuple()) { + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase buf, + AutotunerUtil::CreateBuffer(*redzone_allocator_, output_shape, + config, rng_state)); + output_buffers_.push_back(buf); + output_shape_ = output_shape; + return OkStatus(); + } + + // The output is a tuple. + + auto current_shape_it = output_shape.tuple_shapes().begin(); + auto end = output_shape.tuple_shapes().end(); + end -= buffers_to_create == kAllInputsAllOutputs ? 0 : 1; + + output_shape_ = std::distance(current_shape_it, end) == 1 + ? *current_shape_it + : ShapeUtil::MakeTupleShape( + std::vector{current_shape_it, end}); + + for (; current_shape_it < end; current_shape_it++) { + if (current_shape_it->IsTuple()) { + return Unimplemented("Nested tuples are unsupported by RedzoneBuffers."); + } + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase buf, + AutotunerUtil::CreateBuffer(*redzone_allocator_, *current_shape_it, + config, rng_state)); + output_buffers_.push_back(buf); + } + return OkStatus(); +} + +} // namespace gpu +} // namespace xla + diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h new file mode 100644 index 00000000000000..25a275906d03af --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h @@ -0,0 +1,98 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace xla { +namespace gpu { + +// A RedZone allocator and a collection of buffers that store the inputs and +// outputs of an HloInstruction. These are used when running the instruction +// for autotuning. +class RedzoneBuffers { + public: + enum BuffersToCreate { + // Create a buffer for all of the instruction's operands. The result shape + // is ignored. + kAllInputs = 0, + // Create a buffer for all of the instruction's operands and the entire + // result shape. If the result shape is a tuple, a separate buffer is + // created for each subshape. + kAllInputsAllOutputs = 1, + // Create a buffer for all of the instruction's operands and all of the + // subshapes of the result tuple, except for the last one. The last subshape + // is considered a scratch buffer and is assumed to be allocated elsewhere. + // If the result shape is not a tuple, this will create a buffer + // corresponding to the entire shape - equivalent to `kAllInputsAllOutputs`. + kAllInputsOutputsNoScratch = 2, + }; + static StatusOr FromInstruction( + const HloInstruction& instruction, const AutotuneConfig& config, + se::Stream *stream, + const DebugOptions& debug_options, BuffersToCreate buffers_to_create); + + static StatusOr FromShapes( + std::vector&& input_shapes, const Shape& output_shape, + const AutotuneConfig& config, se::Stream *stream, + const DebugOptions& debug_options, BuffersToCreate buffers_to_create); + + const std::vector& input_buffers() const { + return input_buffers_; + } + const std::vector& input_shapes() const { return input_shapes_; } + const std::vector& output_buffers() const { + return output_buffers_; + } + const Shape& output_shape() const { return output_shape_; } + se::RedzoneAllocator& RedzoneAllocator() const { return *redzone_allocator_; } + + private: + Status CreateInputs(std::vector&& input_shapes, + const AutotuneConfig& config, + int64_t& rng_state); + + Status CreateOutputs(const Shape& output_shape, + const AutotuneConfig& config, + BuffersToCreate buffers_to_create, + int64_t& rng_state); + + std::unique_ptr redzone_allocator_; + std::vector input_buffers_; + std::vector input_shapes_; + std::vector output_buffers_; + Shape output_shape_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ + diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_util.cc new file mode 100644 index 00000000000000..c262f1c9c8bf29 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/autotuner_util.cc @@ -0,0 +1,247 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +//#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" +#include "tensorflow/core/lib/strings/base64.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep + + +#define kCsvSep ',' +#define kCsvComment '#' + + +namespace xla { +namespace gpu { + +using AutotuneCacheMap = absl::flat_hash_map; + +static absl::Mutex autotune_cache_mu(absl::kConstInit); +static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) = + *new AutotuneCacheMap(); + +namespace { + +void CSVLegend(std::ostream& os, bool full_string=false) { + + os << kCsvComment << " m" << kCsvSep << "n" << kCsvSep << "k" << kCsvSep + << "batch_count" << kCsvSep << "trans_a" << kCsvSep + << "trans_b" << kCsvSep << "type_a" << kCsvSep << "type_b" << kCsvSep + << "type_c" << kCsvSep << "lda" << kCsvSep << "ldb" << kCsvSep + << "ldc" << kCsvSep << "stride_a" << kCsvSep + << "stride_b" << kCsvSep << "stride_c"; + if (full_string) { + os << kCsvSep << "alpha_re" << kCsvSep << "alpha_im" << kCsvSep + << "beta" << kCsvSep << "epilogue"; + } + os << kCsvSep << "alg_index" << std::endl; +} + +} // namespace + + +/*static*/ auto AutotunerUtil::AddResultToInMemoryCache( + const AutotuneCacheKey& key, CacheValue result, + const AutotuneConfig& cfg) -> const CacheValue& + ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { + + static std::unique_ptr< std::ofstream > s_dump_fs; + absl::MutexLock lock(&autotune_cache_mu); + auto res = autotune_cache.emplace(key, std::move(result)); + auto it = res.first; + + auto dump_path = cfg.dump_path(); + if (res.second && !dump_path.empty()) { + if (!s_dump_fs) + { + s_dump_fs = std::make_unique< std::ofstream >(std::string(dump_path)); + if (!s_dump_fs->is_open()) { + LOG(WARNING) << "Unable to open: " << dump_path << " for writing!"; + } + CSVLegend(*s_dump_fs, true); + } + *s_dump_fs << key.Get() << kCsvSep << it->second << std::endl; + } + return it->second; +} + +/*static*/ auto AutotunerUtil::TryToFindInInMemoryCache( + const AutotuneCacheKey& key) -> absl::optional + ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { + absl::MutexLock lock(&autotune_cache_mu); + auto it = autotune_cache.find(key); + if (it == autotune_cache.end()) { + return absl::nullopt; + } + return it->second; +} + +/*static*/ void AutotunerUtil::ClearAutotuneResults() { + absl::MutexLock lock(&autotune_cache_mu); + autotune_cache.clear(); +} + +/* static*/ StatusOr AutotunerUtil::CreateBuffer( + se::RedzoneAllocator& allocator, const Shape& shape, + const AutotuneConfig& config, int64_t& rng_state) { + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, + allocator.AllocateBytes(ShapeUtil::ByteSizeOf(shape))); + if (config.should_init_buffers()) { + InitializeBuffer(allocator.stream(), shape.element_type(), &rng_state, + buffer); + } + return buffer; +} + +namespace { +std::string ToCanonicalString(const HloInstruction* instr) { + auto options = HloPrintOptions::Canonical(); + if (instr->opcode() != HloOpcode::kFusion) { + options.set_print_backend_config(true); + return instr->ToString(options); + } + options.set_print_subcomputation_mode( + HloPrintOptions::PrintSubcomputationMode::kOff); + // options.set_print_infeed_outfeed_config(false); + // options.set_print_only_essential_constants(true); + options.set_print_operand_shape(true); + options.set_print_ids(false); + // options.set_canonicalize_computations(true); + + // TODO(b/266210099): This is unsound. We should probably do the fingerprint + // of the HLO computation proto instead. + return instr->called_computations()[0]->ToString(options); +} + +} // namespace + +/*static*/ auto AutotunerUtil::Autotune( + const std::string& str_key, const AutotuneConfig& cfg, + const AutotuneNoCacheFn& autotune_fn) -> StatusOr { + + AutotuneCacheKey key(str_key); + auto opt_res = TryToFindInInMemoryCache(key); + if (opt_res.has_value()) { + VLOG(1) << "In-memory autotune cache hit: key = " << key.Get(); + return *opt_res; + } + VLOG(1) << "Autotuning for key = " << key.Get() << " needed"; + TF_ASSIGN_OR_RETURN(auto result, autotune_fn()); + return AddResultToInMemoryCache(key, result, cfg); +} + +/*static*/ Status AutotunerUtil::LoadAutotuneResultsFromFileOnce( + const AutotuneConfig& cfg) { + + auto status = OkStatus(); + static absl::once_flag once; + absl::call_once(once, [&cfg, &status] { + status = LoadAutotuneResultsFromFile(cfg); + }); + TF_RETURN_IF_ERROR(status); + return status; +} + +/*static*/ Status AutotunerUtil::LoadAutotuneResultsFromFile( + const AutotuneConfig& cfg) { + + auto file_path = cfg.load_path(); + if (file_path.empty()) return OkStatus(); + + std::ifstream ifs{std::string(file_path)}; + if (!ifs.is_open()) { + LOG(WARNING) << "Unable to open autotune file for reading: " << file_path; + return OkStatus(); + } + + std::vector< std::pair< AutotuneCacheKey, CacheValue >> vec; + vec.reserve(256); + + std::string line; + while(std::getline(ifs, line)) + { + line.erase(0, line.find_first_not_of(" \t\n\r\f\v")); + if (line.empty() || line[0] == '#') continue; + + std::istringstream iss(line); + auto pos = line.find_last_of(kCsvSep); + if (pos == std::string::npos) { + LOG(WARNING) << "Unable to parse CSV row: " << line; + continue; + } + auto key = line.substr(0, pos), sval = line.substr(pos + 1); + char* p_end{}; + auto ival = std::strtol(sval.c_str(), &p_end, 10); + if (p_end == sval.c_str()) { + LOG(WARNING) << "Unable to parse CSV row: " << line; + continue; + } + vec.emplace_back(AutotuneCacheKey{key}, CacheValue{}); + VLOG(1) << "Read autotune cache line: " << key << " -> " << ival; + vec.back().second = ival; + } + for(const auto& p : vec) { + AddResultToInMemoryCache(p.first, p.second, cfg); + } + + LOG(INFO) << "Autotune results loaded from file: " << file_path; + return OkStatus(); +} + +/*static*/ StatusOr +AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config, + se::Stream *stream, + const DebugOptions& opts) { + return se::RedzoneAllocator( + stream, config.GetAllocator(), PtxOptsFromDebugOptions(opts), + /*memory_limit=*/std::numeric_limits::max(), + /*redzone_size=*/config.should_check_correctness() + ? opts.xla_gpu_redzone_padding_bytes() + : 0); +} + +/*static*/ AutotuneConvCacheKey AutotunerUtil::ConvCacheKeyFromInstruction( + const HloInstruction* instr, absl::string_view model_str) { + + auto options = HloPrintOptions::Canonical(); + options.set_print_backend_config(true); + return std::make_tuple(std::string(model_str), instr->ToString(options)); +} + +} // namespace gpu +} // namespace xla + diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_util.h new file mode 100644 index 00000000000000..2095ab3226a07d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/autotuner_util.h @@ -0,0 +1,177 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_ +#define XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/types/variant.h" + +#include "tensorflow/tsl/protobuf/autotuning.pb.h" +//#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/stream_executor/device_description.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace xla { +namespace gpu { + +struct DeviceConfig { + se::StreamExecutor* stream_exec; // never null + + // If the `allocator` parameter is not null, we will use it to allocate temp + // memory while timing the various convolution algorithms. If it's null, + // we'll use the default allocator on the StreamExecutor. + se::DeviceMemoryAllocator* allocator = nullptr; // may be null +}; + +class AutotuneCacheKey { + public: + explicit AutotuneCacheKey(const std::string& s) : key_(s) { } + + absl::string_view Get() const { return key_; } + + template + friend H AbslHashValue(H h, const AutotuneCacheKey& w) { + return H::combine(std::move(h), w.key_); + } + + bool operator==(const AutotuneCacheKey& w) const { + return key_ == w.key_; + } + + private: + std::string key_; +}; + +using AutotuneConvCacheKey = + std::tupleGetDeviceDescription().model_str()*/, + std::string /* instr->ToString(HloPrintOptions::Canonical()) */>; + +// using AutotuneConvCacheMap = +// absl::flat_hash_map; + + +class AutotuneConfig { + public: + bool should_init_buffers() const { return autotune_level_ >= 2; } + bool should_reinit_output_buffer() const { return autotune_level_ >= 3; } + bool should_check_correctness() const { return autotune_level_ >= 4; } + bool should_skip_wrong_results() const { return autotune_level_ >= 5; } + bool should_crash_on_check_failure() const { + return should_crash_on_check_failure_; + } + + absl::string_view dump_path() const { return dump_path_; } + absl::string_view load_path() const { return load_path_; } + + AutotuneConfig(const AutotuneConfig& right) + : config_(right.config_), + autotune_level_(right.autotune_level_), + should_crash_on_check_failure_(right.should_crash_on_check_failure_), + dump_path_(right.dump_path_), + load_path_(right.load_path_) + {} + + AutotuneConfig(const DeviceConfig& config, + const DebugOptions& debug_options) + : config_(config), + autotune_level_(debug_options.xla_gpu_autotune_level()), + should_crash_on_check_failure_( + debug_options.xla_gpu_crash_on_verification_failures()), + dump_path_(debug_options.xla_gpu_dump_autotune_results_to()), + load_path_(debug_options.xla_gpu_load_autotune_results_from()) + {} + + se::StreamExecutor* GetExecutor() const { + CHECK(config_.stream_exec != nullptr); + return config_.stream_exec; + } + + se::DeviceMemoryAllocator* GetAllocator() const { + if (config_.allocator != nullptr) { + return config_.allocator; + } + if (allocator_ == nullptr) { + allocator_ = + std::make_unique(GetExecutor()); + } + return allocator_.get(); + } + + private: + DeviceConfig config_; + int32_t autotune_level_; + bool should_crash_on_check_failure_; + std::string dump_path_, load_path_; + mutable std::unique_ptr allocator_; +}; + +struct AutotunerUtil { + + using CacheValue = int64_t; // algorithm ID + using AutotuneNoCacheFn = std::function()>; + + // Create a buffer for a given operation using redzone checker, initialize + // based on a given rng state. + static StatusOr CreateBuffer( + se::RedzoneAllocator& allocator, const Shape& shape, + const AutotuneConfig& config, int64_t& rng_state); + + static StatusOr Autotune( + const std::string& gemm_config, const AutotuneConfig& config, + const AutotuneNoCacheFn& autotune_fn); + + static absl::optional TryToFindInInMemoryCache( + const AutotuneCacheKey& key); + + static const CacheValue& AddResultToInMemoryCache( + const AutotuneCacheKey& key, CacheValue result, + const AutotuneConfig& cfg); + + // Creates a RedzoneAllocator from a given config. + static StatusOr CreateRedzoneAllocator( + const AutotuneConfig& config, se::Stream *stream, + const DebugOptions& opts); + + // Loads autotune results from a file. + // + // Warning: The results are only loaded to the in-memory cache. + static Status LoadAutotuneResultsFromFile(const AutotuneConfig& config); + // Same as above but do it only once! + static Status LoadAutotuneResultsFromFileOnce(const AutotuneConfig& config); + + // Warning: This only clears the in-memory cache. If you use a file based + // cache you're responsible for clearing the cache directory when you want to. + static void ClearAutotuneResults(); + + static AutotuneConvCacheKey ConvCacheKeyFromInstruction( + const HloInstruction* instr, absl::string_view model_str); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_ + diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 836bfbf237959e..59e894c01bdf96 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,675 +17,84 @@ limitations under the License. #include #include -#include +#include +#include +#include +#include -#include "absl/base/call_once.h" -#include "absl/strings/str_replace.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" -#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" -#include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" #include "tensorflow/compiler/xla/stream_executor/kernel.h" -#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace gpu { -static constexpr double kTolerance = 0.1f; - -// Comparison kernel code: compare two buffers of -// bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the relative -// error does not exceed the passed rel_error_threshold. Write the number of -// mismatches into out parameter mismatch_count. -// -// NaN's are considered equal, and for half's we clamp all numbers to largest -// and smallest numbers representable to avoid miscomparisons due to overflows. -// -// The PTX below is compiled from the following CUDA code: -// -// #include -// #include -// -// namespace { -// -// __device__ __inline__ float __xla_buffer_comparator_canonicalize(float input) -// { -// // All fp16 infinities are treated as 65505 or -65505, in order to avoid -// // differences due to overflows. -// return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f)); -// } -// -// } // end anonymous namespace -// -// extern "C" { // avoid name mangling -// -// -// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float elem_a = __half2float(buffer_a[idx]); -// float elem_b = __half2float(buffer_b[idx]); -// elem_a = __xla_buffer_comparator_canonicalize(elem_a); -// elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) return; -// -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float elem_a = buffer_a[idx]; -// float elem_b = buffer_b[idx]; -// if (isnan(elem_a) && isnan(elem_b)) return; -// if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) -// return; -// -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// -// double elem_a = buffer_a[idx]; -// double elem_b = buffer_b[idx]; -// if (isnan(elem_a) && isnan(elem_b)) return; -// if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) -// return; -// double rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// __global__ void __xla_bf16_comparison(__nv_bfloat16* buffer_a, -// __nv_bfloat16* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float elem_a = __bfloat162float(buffer_a[idx]); -// float elem_b = __bfloat162float(buffer_b[idx]); -// elem_a = __xla_buffer_comparator_canonicalize(elem_a); -// elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) return; -// -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// // TODO(b/191520348): The comparison below requires exact equality. -// __global__ void __xla_int8_comparison(int8_t* buffer_a, int8_t* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float a = buffer_a[idx]; -// float b = buffer_b[idx]; -// float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// -// __global__ void __xla_int32_comparison(int* buffer_a, int* buffer_b, -// float rel_error_threshold, -// unsigned long long buffer_length, -// int* mismatch_count) { -// int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// float elem_a = static_cast(buffer_a[idx]); -// float elem_b = static_cast(buffer_b[idx]); -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); -// } -// } // end extern declaration -static const char* buffer_compare_ptx = R"( -// -// Generated by LLVM NVPTX Back-End -// - -.version 4.2 -.target sm_30 -.address_size 64 - -// .globl__xla_fp16_comparison - -.visible .entry __xla_fp16_comparison( -.param .u64 __xla_fp16_comparison_param_0, -.param .u64 __xla_fp16_comparison_param_1, -.param .f32 __xla_fp16_comparison_param_2, -.param .u64 __xla_fp16_comparison_param_3, -.param .u64 __xla_fp16_comparison_param_4 -) -{ -.reg .pred %p<10>; -.reg .b16 %rs<3>; -.reg .f32 %f<20>; -.reg .b32 %r<6>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_fp16_comparison_param_3]; -mov.u32 %r1, %tid.x; -mov.u32 %r2, %ctaid.x; -mov.u32 %r3, %ntid.x; -mad.lo.s32 %r4, %r3, %r2, %r1; -cvt.s64.s32 %rd4, %r4; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB0_4; -ld.param.u64 %rd5, [__xla_fp16_comparison_param_0]; -ld.param.u64 %rd7, [__xla_fp16_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 1; -add.s64 %rd10, %rd3, %rd9; -ld.global.u16 %rs1, [%rd10]; -// begin inline asm -{ cvt.f32.f16 %f6, %rs1;} - -// end inline asm -add.s64 %rd11, %rd2, %rd9; -ld.global.u16 %rs2, [%rd11]; -// begin inline asm -{ cvt.f32.f16 %f7, %rs2;} - -// end inline asm -abs.f32 %f8, %f6; -setp.gtu.f32 %p2, %f8, 0f7F800000; -min.f32 %f9, %f6, 0f477FE100; -max.f32 %f10, %f9, 0fC77FE100; -selp.f32 %f1, %f6, %f10, %p2; -abs.f32 %f11, %f7; -setp.gtu.f32 %p3, %f11, 0f7F800000; -min.f32 %f12, %f7, 0f477FE100; -max.f32 %f13, %f12, 0fC77FE100; -selp.f32 %f2, %f7, %f13, %p3; -abs.f32 %f3, %f1; -setp.gtu.f32 %p4, %f3, 0f7F800000; -abs.f32 %f4, %f2; -setp.gtu.f32 %p5, %f4, 0f7F800000; -and.pred %p6, %p4, %p5; -@%p6 bra LBB0_4; -ld.param.f32 %f5, [__xla_fp16_comparison_param_2]; -sub.f32 %f14, %f1, %f2; -abs.f32 %f15, %f14; -max.f32 %f16, %f3, %f4; -add.f32 %f17, %f16, 0f3F800000; -div.rn.f32 %f18, %f15, %f17; -setp.gt.f32 %p7, %f18, %f5; -abs.f32 %f19, %f18; -setp.gtu.f32 %p8, %f19, 0f7F800000; -or.pred %p9, %p7, %p8; -@!%p9 bra LBB0_4; -bra.uni LBB0_3; -LBB0_3: -ld.param.u64 %rd6, [__xla_fp16_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r5, [%rd1], 1; -LBB0_4: -ret; - -} -// .globl__xla_fp32_comparison -.visible .entry __xla_fp32_comparison( -.param .u64 __xla_fp32_comparison_param_0, -.param .u64 __xla_fp32_comparison_param_1, -.param .f32 __xla_fp32_comparison_param_2, -.param .u64 __xla_fp32_comparison_param_3, -.param .u64 __xla_fp32_comparison_param_4 -) -{ -.reg .pred %p<12>; -.reg .f32 %f<12>; -.reg .b32 %r<9>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_fp32_comparison_param_3]; -mov.u32 %r1, %tid.x; -mov.u32 %r2, %ctaid.x; -mov.u32 %r3, %ntid.x; -mad.lo.s32 %r4, %r3, %r2, %r1; -cvt.s64.s32 %rd4, %r4; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB1_6; -ld.param.u64 %rd5, [__xla_fp32_comparison_param_0]; -ld.param.u64 %rd7, [__xla_fp32_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 2; -add.s64 %rd10, %rd3, %rd9; -ld.global.f32 %f1, [%rd10]; -add.s64 %rd11, %rd2, %rd9; -ld.global.f32 %f2, [%rd11]; -abs.f32 %f3, %f1; -setp.gtu.f32 %p2, %f3, 0f7F800000; -abs.f32 %f4, %f2; -setp.gtu.f32 %p3, %f4, 0f7F800000; -and.pred %p4, %p2, %p3; -@%p4 bra LBB1_6; -setp.eq.f32 %p5, %f3, 0f7F800000; -setp.eq.f32 %p6, %f4, 0f7F800000; -and.pred %p7, %p5, %p6; -@!%p7 bra LBB1_4; -bra.uni LBB1_3; -LBB1_3: -mov.b32 %r5, %f1; -mov.b32 %r6, %f2; -xor.b32 %r7, %r6, %r5; -setp.gt.s32 %p8, %r7, -1; -@%p8 bra LBB1_6; -LBB1_4: -ld.param.f32 %f5, [__xla_fp32_comparison_param_2]; -sub.f32 %f6, %f1, %f2; -abs.f32 %f7, %f6; -max.f32 %f8, %f3, %f4; -add.f32 %f9, %f8, 0f3F800000; -div.rn.f32 %f10, %f7, %f9; -setp.gt.f32 %p9, %f10, %f5; -abs.f32 %f11, %f10; -setp.gtu.f32 %p10, %f11, 0f7F800000; -or.pred %p11, %p9, %p10; -@!%p11 bra LBB1_6; -bra.uni LBB1_5; -LBB1_5: -ld.param.u64 %rd6, [__xla_fp32_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r8, [%rd1], 1; -LBB1_6: -ret; - -} -// .globl__xla_fp64_comparison -.visible .entry __xla_fp64_comparison( -.param .u64 __xla_fp64_comparison_param_0, -.param .u64 __xla_fp64_comparison_param_1, -.param .f32 __xla_fp64_comparison_param_2, -.param .u64 __xla_fp64_comparison_param_3, -.param .u64 __xla_fp64_comparison_param_4 -) -{ -.reg .pred %p<16>; -.reg .f32 %f<2>; -.reg .b32 %r<13>; -.reg .f64 %fd<12>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_fp64_comparison_param_3]; -mov.u32 %r2, %tid.x; -mov.u32 %r3, %ctaid.x; -mov.u32 %r4, %ntid.x; -mad.lo.s32 %r5, %r4, %r3, %r2; -cvt.s64.s32 %rd4, %r5; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB2_6; -ld.param.u64 %rd5, [__xla_fp64_comparison_param_0]; -ld.param.u64 %rd7, [__xla_fp64_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 3; -add.s64 %rd10, %rd3, %rd9; -ld.global.f64 %fd1, [%rd10]; -add.s64 %rd11, %rd2, %rd9; -ld.global.f64 %fd2, [%rd11]; -abs.f64 %fd3, %fd1; -setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000; -abs.f64 %fd4, %fd2; -setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000; -and.pred %p4, %p2, %p3; -@%p4 bra LBB2_6; -{ -.reg .b32 %temp; -mov.b64 {%r6, %temp}, %fd1; -} -{ -.reg .b32 %temp; -mov.b64 {%temp, %r1}, %fd1; -} -and.b32 %r7, %r1, 2147483647; -setp.eq.s32 %p5, %r7, 2146435072; -setp.eq.s32 %p6, %r6, 0; -and.pred %p7, %p5, %p6; -@!%p7 bra LBB2_4; -bra.uni LBB2_3; -LBB2_3: -{ -.reg .b32 %temp; -mov.b64 {%r8, %temp}, %fd2; -} -{ -.reg .b32 %temp; -mov.b64 {%temp, %r9}, %fd2; -} -and.b32 %r10, %r9, 2147483647; -setp.eq.s32 %p8, %r10, 2146435072; -setp.eq.s32 %p9, %r8, 0; -and.pred %p10, %p8, %p9; -xor.b32 %r11, %r9, %r1; -setp.gt.s32 %p11, %r11, -1; -and.pred %p12, %p10, %p11; -@%p12 bra LBB2_6; -LBB2_4: -ld.param.f32 %f1, [__xla_fp64_comparison_param_2]; -sub.f64 %fd5, %fd1, %fd2; -abs.f64 %fd6, %fd5; -max.f64 %fd7, %fd3, %fd4; -add.f64 %fd8, %fd7, 0d3FF0000000000000; -div.rn.f64 %fd9, %fd6, %fd8; -cvt.f64.f32 %fd10, %f1; -setp.gt.f64 %p13, %fd9, %fd10; -abs.f64 %fd11, %fd9; -setp.gtu.f64 %p14, %fd11, 0d7FF0000000000000; -or.pred %p15, %p13, %p14; -@!%p15 bra LBB2_6; -bra.uni LBB2_5; -LBB2_5: -ld.param.u64 %rd6, [__xla_fp64_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r12, [%rd1], 1; -LBB2_6: -ret; - -} -// .globl__xla_bf16_comparison -.visible .entry __xla_bf16_comparison( -.param .u64 __xla_bf16_comparison_param_0, -.param .u64 __xla_bf16_comparison_param_1, -.param .f32 __xla_bf16_comparison_param_2, -.param .u64 __xla_bf16_comparison_param_3, -.param .u64 __xla_bf16_comparison_param_4 -) -{ -.reg .pred %p<10>; -.reg .b16 %rs<3>; -.reg .f32 %f<20>; -.reg .b32 %r<6>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_bf16_comparison_param_3]; -mov.u32 %r1, %tid.x; -mov.u32 %r2, %ctaid.x; -mov.u32 %r3, %ntid.x; -mad.lo.s32 %r4, %r3, %r2, %r1; -cvt.s64.s32 %rd4, %r4; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB3_4; -ld.param.u64 %rd5, [__xla_bf16_comparison_param_0]; -ld.param.u64 %rd7, [__xla_bf16_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 1; -add.s64 %rd10, %rd3, %rd9; -ld.global.u16 %rs1, [%rd10]; -// begin inline asm -{ mov.b32 %f6, {0,%rs1};} - -// end inline asm -add.s64 %rd11, %rd2, %rd9; -ld.global.u16 %rs2, [%rd11]; -// begin inline asm -{ mov.b32 %f7, {0,%rs2};} - -// end inline asm -abs.f32 %f8, %f6; -setp.gtu.f32 %p2, %f8, 0f7F800000; -min.f32 %f9, %f6, 0f477FE100; -max.f32 %f10, %f9, 0fC77FE100; -selp.f32 %f1, %f6, %f10, %p2; -abs.f32 %f11, %f7; -setp.gtu.f32 %p3, %f11, 0f7F800000; -min.f32 %f12, %f7, 0f477FE100; -max.f32 %f13, %f12, 0fC77FE100; -selp.f32 %f2, %f7, %f13, %p3; -abs.f32 %f3, %f1; -setp.gtu.f32 %p4, %f3, 0f7F800000; -abs.f32 %f4, %f2; -setp.gtu.f32 %p5, %f4, 0f7F800000; -and.pred %p6, %p4, %p5; -@%p6 bra LBB3_4; -ld.param.f32 %f5, [__xla_bf16_comparison_param_2]; -sub.f32 %f14, %f1, %f2; -abs.f32 %f15, %f14; -max.f32 %f16, %f3, %f4; -add.f32 %f17, %f16, 0f3F800000; -div.rn.f32 %f18, %f15, %f17; -setp.gt.f32 %p7, %f18, %f5; -abs.f32 %f19, %f18; -setp.gtu.f32 %p8, %f19, 0f7F800000; -or.pred %p9, %p7, %p8; -@!%p9 bra LBB3_4; -bra.uni LBB3_3; -LBB3_3: -ld.param.u64 %rd6, [__xla_bf16_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r5, [%rd1], 1; -LBB3_4: -ret; - -} -// .globl__xla_int8_comparison -.visible .entry __xla_int8_comparison( -.param .u64 __xla_int8_comparison_param_0, -.param .u64 __xla_int8_comparison_param_1, -.param .f32 __xla_int8_comparison_param_2, -.param .u64 __xla_int8_comparison_param_3, -.param .u64 __xla_int8_comparison_param_4 -) -{ - .reg .pred %p<5>; - .reg .f32 %f<12>; - .reg .b32 %r<8>; - .reg .b64 %rd<11>; - - ld.param.u64 %rd8, [__xla_int8_comparison_param_3]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.s64.s32 %rd4, %r4; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB7_3; - ld.param.f32 %f1, [__xla_int8_comparison_param_2]; - ld.param.u64 %rd5, [__xla_int8_comparison_param_0]; - ld.param.u64 %rd7, [__xla_int8_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - add.s64 %rd9, %rd3, %rd4; - ld.global.s8 %r5, [%rd9]; - add.s64 %rd10, %rd2, %rd4; - ld.global.s8 %r6, [%rd10]; - cvt.rn.f32.s32 %f2, %r5; - cvt.rn.f32.s32 %f3, %r6; - sub.f32 %f4, %f2, %f3; - abs.f32 %f5, %f4; - abs.f32 %f6, %f2; - abs.f32 %f7, %f3; - max.f32 %f8, %f6, %f7; - add.f32 %f9, %f8, 0f3F800000; - div.rn.f32 %f10, %f5, %f9; - setp.leu.f32 %p2, %f10, %f1; - abs.f32 %f11, %f10; - setp.le.f32 %p3, %f11, 0f7F800000; - and.pred %p4, %p2, %p3; - @%p4 bra LBB7_3; - ld.param.u64 %rd6, [__xla_int8_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r7, [%rd1], 1; -LBB7_3: - ret; -} - -// .globl__xla_int32_comparison -.visible .entry __xla_int32_comparison( -.param .u64 __xla_int32_comparison_param_0, -.param .u64 __xla_int32_comparison_param_1, -.param .f32 __xla_int32_comparison_param_2, -.param .u64 __xla_int32_comparison_param_3, -.param .u64 __xla_int32_comparison_param_4 -) -{ -.reg .pred %p<5>; -.reg .f32 %f<12>; -.reg .b32 %r<8>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_int32_comparison_param_3]; -mov.u32 %r1, %tid.x; -mov.u32 %r2, %ctaid.x; -mov.u32 %r3, %ntid.x; -mad.lo.s32 %r4, %r3, %r2, %r1; -cvt.s64.s32 %rd4, %r4; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB5_3; -ld.param.f32 %f1, [__xla_int32_comparison_param_2]; -ld.param.u64 %rd5, [__xla_int32_comparison_param_0]; -ld.param.u64 %rd7, [__xla_int32_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 2; -add.s64 %rd10, %rd3, %rd9; -ld.global.u32 %r5, [%rd10]; -cvt.rn.f32.s32 %f2, %r5; -add.s64 %rd11, %rd2, %rd9; -ld.global.u32 %r6, [%rd11]; -cvt.rn.f32.s32 %f3, %r6; -sub.f32 %f4, %f2, %f3; -abs.f32 %f5, %f4; -abs.f32 %f6, %f2; -abs.f32 %f7, %f3; -max.f32 %f8, %f6, %f7; -add.f32 %f9, %f8, 0f3F800000; -div.rn.f32 %f10, %f5, %f9; -setp.gt.f32 %p2, %f10, %f1; -abs.f32 %f11, %f10; -setp.gtu.f32 %p3, %f11, 0f7F800000; -or.pred %p4, %p2, %p3; -@!%p4 bra LBB5_3; -bra.uni LBB5_2; -LBB5_2: -ld.param.u64 %rd6, [__xla_int32_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r7, [%rd1], 1; -LBB5_3: -ret; - -} -)"; - template using ComparisonKernelT = se::TypedKernel, se::DeviceMemory, float, uint64_t, se::DeviceMemory>; +struct ComparisonParams { + double relative_tol = 0.1; + bool verbose = true; + const Shape *shape = nullptr; + se::Stream* stream = nullptr; + se::DeviceMemoryBase current{}; + se::DeviceMemoryBase expected{}; +}; + // Compares two buffers on the GPU. // // Returns `true` if two buffers are equal, `false` otherwise. template -static StatusOr DeviceCompare(se::Stream* stream, - se::DeviceMemoryBase lhs, - se::DeviceMemoryBase rhs, - const Shape& buffer_shape, - const HloModuleConfig& config, - absl::string_view kernel_name) { - se::StreamExecutor* executor = stream->parent(); +static StatusOr DeviceCompare( + absl::string_view kernel_name, void* kernel_symbol, + const ComparisonParams& params) { + se::StreamExecutor* executor = params.stream->parent(); se::ScopedDeviceMemory out_param = executor->AllocateOwnedScalar(); - stream->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); - if (lhs.size() != rhs.size()) { - return InternalError("Mismatched buffer size: %d bytes vs. %d bytes", - lhs.size(), rhs.size()); + params.stream->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); + if (params.current.size() != params.expected.size()) { + return Internal("Mismatched buffer size: %d bytes vs. %d bytes", + params.current.size(), params.expected.size()); } - se::DeviceMemory lhs_typed(lhs); - se::DeviceMemory rhs_typed(rhs); - uint64_t buffer_size = lhs_typed.ElementCount(); - - absl::Span compiled_ptx = {}; - StatusOr> compiled_ptx_or = - se::CompileGpuAsmOrGetCached( - executor->device_ordinal(), buffer_compare_ptx, - PtxOptsFromDebugOptions(config.debug_options())); - if (compiled_ptx_or.ok()) { - compiled_ptx = std::move(compiled_ptx_or).value(); - } else { - static absl::once_flag ptxas_not_found_logged; - absl::call_once(ptxas_not_found_logged, [&]() { - LOG(WARNING) - << compiled_ptx_or.status().ToString() - << "\nRelying on driver to perform ptx compilation. " - << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda " - << " or modifying $PATH can be used to set the location of ptxas" - << "\nThis message will only be logged once."; - }); - } + se::DeviceMemory current_typed(params.current); + se::DeviceMemory expected_typed(params.expected); + uint64_t buffer_size = current_typed.ElementCount(); TF_ASSIGN_OR_RETURN( std::unique_ptr> comparison_kernel, (executor->CreateTypedKernel, - se::DeviceMemory, float, uint64_t, - se::DeviceMemory>( - kernel_name, buffer_compare_ptx, compiled_ptx))); - - GpuDeviceInfo gpu_device_info; - gpu_device_info.threads_per_block_limit = - executor->GetDeviceDescription().threads_per_block_limit(); - gpu_device_info.threads_per_warp = - executor->GetDeviceDescription().threads_per_warp(); - gpu_device_info.shared_memory_per_block = - executor->GetDeviceDescription().shared_memory_per_block(); - gpu_device_info.threads_per_core_limit = - executor->GetDeviceDescription().threads_per_core_limit(); - gpu_device_info.core_count = executor->GetDeviceDescription().core_count(); - gpu_device_info.block_dim_limit_x = - executor->GetDeviceDescription().block_dim_limit().x; - gpu_device_info.block_dim_limit_y = - executor->GetDeviceDescription().block_dim_limit().y; - gpu_device_info.block_dim_limit_z = - executor->GetDeviceDescription().block_dim_limit().z; - - TF_ASSIGN_OR_RETURN(LaunchDimensions dim, - CalculateLaunchDimensions(buffer_shape, gpu_device_info)); - - LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block(); - LaunchDimensions::Dim3D block_counts = dim.block_counts(); - TF_RETURN_IF_ERROR(stream->ThenLaunch( - se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), - se::BlockDim(block_counts.x, block_counts.y, block_counts.z), - *comparison_kernel, lhs_typed, rhs_typed, static_cast(kTolerance), - buffer_size, out_param.cref())); + se::DeviceMemory, float, uint64_t, + se::DeviceMemory>( + kernel_name, kernel_symbol))); + + auto gpu_device_info = GetGpuDeviceInfo(executor); + + TF_ASSIGN_OR_RETURN(auto dim, + CalculateLaunchDimensions(*params.shape, gpu_device_info)); + + auto threads = dim.thread_counts_per_block(), + blocks = dim.block_counts(); + params.stream->ThenLaunch(se::ThreadDim(threads.x, threads.y, threads.z), + se::BlockDim(blocks.x, blocks.y, blocks.z), *comparison_kernel, + current_typed, expected_typed, static_cast(params.relative_tol), + buffer_size, out_param.cref()); uint64_t result = -1; CHECK_EQ(out_param->size(), sizeof(result)); - stream->ThenMemcpy(&result, *out_param, sizeof(result)); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + params.stream->ThenMemcpy(&result, *out_param, sizeof(result)); + TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); return result == 0; } @@ -694,13 +103,13 @@ static StatusOr DeviceCompare(se::Stream* stream, // // Returns true if no differences were seen, false otherwise. template -StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs, - se::DeviceMemoryBase rhs) { - int64_t n = lhs.size() / sizeof(ElementType); - std::vector host_lhs(n), host_rhs(n); - stream->ThenMemcpy(host_lhs.data(), lhs, lhs.size()); - stream->ThenMemcpy(host_rhs.data(), rhs, rhs.size()); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); +static StatusOr HostCompare(const ComparisonParams& params) { + int64_t n = params.current.size() / sizeof(ElementType); + std::vector host_current(n), host_expected(n); + + params.stream->ThenMemcpy(host_current.data(), params.current, params.current.size()); + params.stream->ThenMemcpy(host_expected.data(), params.expected, params.expected.size()); + TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); const auto canonicalize = [](ComparisonType a) -> ComparisonType { if (std::is_same::value && a) { @@ -713,82 +122,89 @@ StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs, return a; }; int differences_seen = 0; - for (int64_t i = 0; i < n && differences_seen < 10; i++) { - auto original_lhs = static_cast(host_lhs[i]); - auto original_rhs = static_cast(host_rhs[i]); - ComparisonType lhs = canonicalize(original_lhs); - ComparisonType rhs = canonicalize(original_rhs); - if (std::isnan(lhs) && std::isnan(rhs)) { + + for (int64_t i = 0; i < n && differences_seen < 10; ++i) { + auto current_value = static_cast(host_current[i]); + auto expected_value = static_cast(host_expected[i]); + ComparisonType current_value_canonical = canonicalize(current_value); + ComparisonType expected_value_canonical = canonicalize(expected_value); + if (std::isnan(current_value_canonical) && + std::isnan(expected_value_canonical)) { continue; } - if (std::isinf(lhs) && std::isinf(rhs) && lhs == rhs) { + if (std::isinf(current_value_canonical) && + std::isinf(expected_value_canonical) && + current_value_canonical == expected_value_canonical) { continue; } - if (std::isfinite(lhs) != std::isfinite(rhs) || - !(std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1) < - kTolerance)) { - differences_seen++; - LOG(ERROR) << "Difference at " << i << ": " << original_lhs << " vs " - << original_rhs; + if (std::isfinite(current_value_canonical) != + std::isfinite(expected_value_canonical) || + !(std::abs(current_value_canonical - expected_value_canonical) / + (std::max(std::abs(current_value_canonical), + std::abs(expected_value_canonical)) + + 1) < params.relative_tol)) { + if(!params.verbose) return false; // Return immediately if not verbose. + ++differences_seen; + LOG(ERROR) << "Difference at " << i << ": " << current_value + << ", expected " << expected_value; } } return differences_seen == 0; } template -static StatusOr CompareEqualParameterized(se::Stream* stream, - se::DeviceMemoryBase lhs, - se::DeviceMemoryBase rhs, - const Shape& shape, - const HloModuleConfig& config, - absl::string_view kernel_name) { +static StatusOr CompareEqualParameterized( + absl::string_view kernel_name, void* kernel_symbol, + const ComparisonParams& params) { XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual"); TF_ASSIGN_OR_RETURN( - bool result, - DeviceCompare(stream, lhs, rhs, shape, config, kernel_name)); + bool result, DeviceCompare(kernel_name, kernel_symbol, params)); if (result) { return true; } - TF_ASSIGN_OR_RETURN(bool host_return, - (HostCompare(stream, lhs, rhs))); + TF_ASSIGN_OR_RETURN(bool host_return, + (HostCompare(params))); CHECK_EQ(host_return, result) << "Host comparison succeeded even though GPU comparison failed."; - return false; } -StatusOr BufferComparator::CompareEqual(se::Stream* stream, - se::DeviceMemoryBase lhs, - se::DeviceMemoryBase rhs) const { +StatusOr BufferComparator::CompareEqual( + se::Stream* stream, se::DeviceMemoryBase current, + se::DeviceMemoryBase expected) const { + + ComparisonParams params{ + relative_tol_, verbose_, &shape_, stream, current, expected}; + switch (shape_.element_type()) { case xla::F16: return CompareEqualParameterized( - stream, lhs, rhs, shape_, config_, "__xla_fp16_comparison"); + "fp16_comparison", buffer_comparator::fp16_comparison(), params); case xla::BF16: - return CompareEqualParameterized( - stream, lhs, rhs, shape_, config_, "__xla_bf16_comparison"); + return CompareEqualParameterized( + "bf16_comparison", buffer_comparator::bf16_comparison(), params); case xla::F32: return CompareEqualParameterized( - stream, lhs, rhs, shape_, config_, "__xla_fp32_comparison"); + "fp32_comparison", buffer_comparator::fp32_comparison(), params); case xla::F64: return CompareEqualParameterized( - stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison"); + "fp64_comparison", buffer_comparator::fp64_comparison(), params); case xla::S8: return CompareEqualParameterized( - stream, lhs, rhs, shape_, config_, "__xla_int8_comparison"); + "int8_comparison", buffer_comparator::int8_comparison(), params); case xla::S32: return CompareEqualParameterized( - stream, lhs, rhs, shape_, config_, "__xla_int32_comparison"); + "int32_comparison", buffer_comparator::int32_comparison(), params); default: return Unimplemented("Unimplemented element type"); } } -BufferComparator::BufferComparator(const Shape& shape, - const HloModuleConfig& config) - : shape_(shape), config_(config) { +BufferComparator::BufferComparator(const Shape& shape, double tolerance, + bool verbose) : + shape_(shape), relative_tol_(tolerance), verbose_(verbose) { // Normalize complex shapes: since we treat the passed array as a contiguous // storage it does not matter which dimension are we doubling. auto double_dim_size = [&]() { @@ -809,3 +225,4 @@ BufferComparator::BufferComparator(const Shape& shape, } // namespace gpu } // namespace xla + diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cu.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cu.cc new file mode 100644 index 00000000000000..fed21e7e56c6ef --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cu.cc @@ -0,0 +1,184 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#include +#include +#include + +using bfloat16 = __nv_bfloat16; +#define BF16_TO_F32 __bfloat162float + +#elif TENSORFLOW_USE_ROCM +#include +#include + +#include "rocm/rocm_config.h" + +using bfloat16 = hip_bfloat16; +#define BF16_TO_F32 float + +#endif + +#include + +namespace xla { +namespace gpu { +namespace buffer_comparator { + +// Comparison kernel code: compare two buffers of +// fp8/bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the +// relative error does not exceed the passed rel_error_threshold. Write the +// number of mismatches into out parameter mismatch_count. + +// NaN's are considered equal, and for half's we clamp all numbers to largest +// and smallest numbers representable to avoid miscomparisons due to overflows. +namespace { + +__device__ __inline__ float Canonicalize(float input) { + // All fp16 infinities are treated as 65505 or -65505, in order to avoid + // differences due to overflows. + return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f)); +} + +__global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float elem_a = __half2float(buffer_a[idx]); + float elem_b = __half2float(buffer_b[idx]); + elem_a = Canonicalize(elem_a); + elem_b = Canonicalize(elem_b); + if (isnan(elem_a) && isnan(elem_b)) return; + + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +__global__ void xla_fp32_comparison(float* buffer_a, float* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float elem_a = buffer_a[idx]; + float elem_b = buffer_b[idx]; + if (isnan(elem_a) && isnan(elem_b)) return; + if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) + return; + + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +__global__ void xla_fp64_comparison(double* buffer_a, double* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + + double elem_a = buffer_a[idx]; + double elem_b = buffer_b[idx]; + if (isnan(elem_a) && isnan(elem_b)) return; + if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) + return; + double rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +__global__ void xla_bf16_comparison(bfloat16* buffer_a, bfloat16* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float elem_a = BF16_TO_F32(buffer_a[idx]); + float elem_b = BF16_TO_F32(buffer_b[idx]); + elem_a = Canonicalize(elem_a); + elem_b = Canonicalize(elem_b); + if (isnan(elem_a) && isnan(elem_b)) return; + + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +// TODO(b/191520348): The comparison below requires exact equality. +__global__ void xla_int8_comparison(int8_t* buffer_a, int8_t* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float a = buffer_a[idx]; + float b = buffer_b[idx]; + float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1); + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +__global__ void xla_int32_comparison(int* buffer_a, int* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + float elem_a = static_cast(buffer_a[idx]); + float elem_b = static_cast(buffer_b[idx]); + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +} // namespace + +void* fp16_comparison() { + return reinterpret_cast(&xla_fp16_comparison); +} + +void* bf16_comparison() { + return reinterpret_cast(&xla_bf16_comparison); +} + +void* fp32_comparison() { + return reinterpret_cast(&xla_fp32_comparison); +} + +void* fp64_comparison() { + return reinterpret_cast(&xla_fp64_comparison); +} + +void* int8_comparison() { + return reinterpret_cast(&xla_int8_comparison); +} + +void* int32_comparison() { + return reinterpret_cast(&xla_int32_comparison); +} + +} // namespace buffer_comparator +} // namespace gpu +} // namespace xla + + + diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h index 8be4decbd055bc..0bb356e071638a 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h @@ -18,10 +18,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/shape.h" + +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" -namespace xla { -namespace gpu { + +#if TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif + +namespace xla::gpu { // A device-side comparator that compares buffers. class BufferComparator { @@ -29,7 +35,8 @@ class BufferComparator { BufferComparator(const BufferComparator&) = delete; BufferComparator(BufferComparator&&) = default; - BufferComparator(const Shape& shape, const HloModuleConfig& config); + BufferComparator(const Shape& shape, double tolerance = 0.1, + bool verbose = true); // Returns true if the two buffers compare equal. The definition of "equal" // is: @@ -40,15 +47,28 @@ class BufferComparator { // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance // // See the implementation for the tolerance value. - StatusOr CompareEqual(se::Stream* stream, se::DeviceMemoryBase lhs, - se::DeviceMemoryBase rhs) const; + StatusOr CompareEqual(se::Stream* stream, + se::DeviceMemoryBase current, + se::DeviceMemoryBase expected) const; private: Shape shape_; - HloModuleConfig config_; + double relative_tol_; // relative tolerance for comparison + bool verbose_; // whether to print out error message on mismatch }; +namespace buffer_comparator { + +// Returns a pointer to CUDA C++ device function implementing comparison. +void* fp16_comparison(); +void* bf16_comparison(); +void* fp32_comparison(); +void* fp64_comparison(); +void* int8_comparison(); +void* int32_comparison(); + +} // namespace buffer_comparator +} // namespace xla::gpu + -} // namespace gpu -} // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cublas_cudnn.cc b/tensorflow/compiler/xla/service/gpu/cublas_cudnn.cc index 6168ca9458b580..c588252ea2d0b5 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_cudnn.cc +++ b/tensorflow/compiler/xla/service/gpu/cublas_cudnn.cc @@ -34,14 +34,8 @@ bool IsCublasLtMatmul(const HloInstruction& hlo) { hlo.custom_call_target() == kCublasLtMatmulCallTarget; } -bool IsCublasLtMatmulF8(const HloInstruction& hlo) { - return hlo.opcode() == HloOpcode::kCustomCall && - hlo.custom_call_target() == kCublasLtMatmulF8CallTarget; -} - const absl::string_view kGemmCallTarget = "__cublas$gemm"; const absl::string_view kCublasLtMatmulCallTarget = "__cublas$lt$matmul"; -const absl::string_view kCublasLtMatmulF8CallTarget = "__cublas$lt$matmul$f8"; const absl::string_view kTriangularSolveCallTarget = "__cublas$triangularSolve"; const absl::string_view kCudnnConvBackwardInputCallTarget = diff --git a/tensorflow/compiler/xla/service/gpu/cublas_cudnn.h b/tensorflow/compiler/xla/service/gpu/cublas_cudnn.h index 3a9fded79035ed..5ac806c9ca9f70 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_cudnn.h +++ b/tensorflow/compiler/xla/service/gpu/cublas_cudnn.h @@ -61,18 +61,12 @@ bool IsLegacyCublasMatmul(const HloInstruction& hlo); // Matrix multiplication that calls into cublasLt. bool IsCublasLtMatmul(const HloInstruction& hlo); -// Scaled matrix multiplication in FP8. Calls into cublasLt. -bool IsCublasLtMatmulF8(const HloInstruction& hlo); - // A call to cuBLAS general matrix multiplication API. extern const absl::string_view kGemmCallTarget; // A call to cuBLAS Lt API matrix multiplication. extern const absl::string_view kCublasLtMatmulCallTarget; -// A call to cuBLASLt for scaled matrix multiplication in FP8. -extern const absl::string_view kCublasLtMatmulF8CallTarget; - // A call to cuBLAS for a triangular solve. // // Like cudnn convolutions, this op returns a tuple (result, scratch_memory). diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc deleted file mode 100644 index 881f40c52232d8..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h" - -#include - -#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" -#include "tensorflow/compiler/xla/service/gpu/thunk.h" -#include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" -#include "tensorflow/compiler/xla/stream_executor/device_memory.h" -#include "tensorflow/tsl/platform/logging.h" - -namespace xla { -namespace gpu { - -CublasLtMatmulThunk::CublasLtMatmulThunk( - ThunkInfo thunk_info, cublas_lt::MatmulPlan plan, int64_t algorithm_idx, - BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, - BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, - BufferAllocation::Slice bias_buffer, BufferAllocation::Slice aux_buffer, - BufferAllocation::Slice a_scale, BufferAllocation::Slice b_scale, - BufferAllocation::Slice c_scale, BufferAllocation::Slice d_scale, - BufferAllocation::Slice d_amax) - : Thunk(Kind::kCublasLtMatmul, thunk_info), - plan_(std::move(plan)), - algorithm_idx_(algorithm_idx), - a_buffer_(a_buffer), - b_buffer_(b_buffer), - c_buffer_(c_buffer), - d_buffer_(d_buffer), - bias_buffer_(bias_buffer), - aux_buffer_(aux_buffer), - a_scale_buffer_(a_scale), - b_scale_buffer_(b_scale), - c_scale_buffer_(c_scale), - d_scale_buffer_(d_scale), - d_amax_buffer_(d_amax) {} - -Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { - if (!algorithm_) { - TF_ASSIGN_OR_RETURN( - std::vector algorithms, - plan_.GetAlgorithms(params.stream)); - TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); - algorithm_ = algorithms[algorithm_idx_]; - } - - VLOG(3) << "Running cublas_lt matmul thunk"; - const BufferAllocations& allocs = *params.buffer_allocations; - - se::DeviceMemoryBase bias, a_scale, b_scale, c_scale, d_scale, d_amax; - if (bias_buffer_.allocation() != nullptr) { - bias = allocs.GetDeviceAddress(bias_buffer_); - } - if (a_scale_buffer_.allocation() != nullptr) { - a_scale = allocs.GetDeviceAddress(a_scale_buffer_); - } - if (b_scale_buffer_.allocation() != nullptr) { - b_scale = allocs.GetDeviceAddress(b_scale_buffer_); - } - if (c_scale_buffer_.allocation() != nullptr) { - c_scale = allocs.GetDeviceAddress(c_scale_buffer_); - } - if (d_scale_buffer_.allocation() != nullptr) { - d_scale = allocs.GetDeviceAddress(d_scale_buffer_); - } - if (d_amax_buffer_.allocation() != nullptr) { - d_amax = allocs.GetDeviceAddress(d_amax_buffer_); - } - - se::DeviceMemoryBase aux; - if (aux_buffer_.allocation() != nullptr) { - aux = allocs.GetDeviceAddress(aux_buffer_); - } - - se::OwningScratchAllocator<> scratch_allocator(allocs.device_ordinal(), - allocs.memory_allocator()); - return plan_.ExecuteOnStream( - params.stream, allocs.GetDeviceAddress(a_buffer_), - allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_), - allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, - d_scale, d_amax, *algorithm_, scratch_allocator); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index d6936053cf1892..11d27e84191655 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,492 +15,383 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" -#include -#include -#include +#include +#include +#include #include #include -#include -#include #include #include #include #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h" +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_serializable_autotuner.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/stream_executor/blas.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/tsl/platform/errors.h" -#include "tensorflow/tsl/platform/logger.h" -#include "tensorflow/tsl/platform/statusor.h" #include "tensorflow/tsl/protobuf/autotuning.pb.h" #include "tensorflow/tsl/util/proto/proto_utils.h" -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" -#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" -#endif - namespace xla { namespace gpu { +namespace { -using tensorflow::AutotuneResult; - -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -se::RedzoneAllocator CreateRedzoneAllocator( - se::Stream* stream, se::DeviceMemoryAllocator* allocator, - const DebugOptions& debug_options, const AutotuneConfig& config) { - int64_t redzone_size = config.should_check_correctness() - ? se::RedzoneAllocator::kDefaultRedzoneSize - : 0; - - return se::RedzoneAllocator( - stream, allocator, PtxOptsFromDebugOptions(debug_options), - /*memory_limit=*/std::numeric_limits::max(), - /*redzone_size=*/redzone_size); -} -#endif +using se::gpu::BlasLt; + +class GemmAutotuner { + const AutotuneConfig& autotune_config_; + RedzoneBuffers rz_buffers_; + std::unique_ptr< se::Stream > stream_; + bool deterministic_ops_ = false; + float gemm_relative_tol_ = 0.1f; + + public: + explicit GemmAutotuner(const AutotuneConfig& autotune_config) + : autotune_config_(autotune_config) {} + + StatusOr operator()( + const GemmConfig& gemm_config, + std::vector< Shape >&& input_shapes, const Shape& output_shape, + const DebugOptions& debug_options) { + + VLOG(3) << "Starting autotune of GemmThunk standalone"; + + if(!stream_) { + stream_ = std::make_unique< se::Stream >(autotune_config_.GetExecutor()); + stream_->Init(); + } -// Returns the index (into `algorithms`) of the fastest algorithm. -template -StatusOr> GetBestAlgorithm( - se::Stream* stream, se::RedzoneAllocator& allocator, - std::optional gemm_str, - const AutotuneConfig& autotune_config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - absl::Span algorithms, const Shape& output_shape, - const HloModuleConfig& hlo_module_config, double beta, - const std::function(const AlgoT&)>& - run_benchmark) { - if (!stream->parent()->SynchronizeAllActivity()) { - return InternalError("Failed to synchronize GPU for autotuning."); + deterministic_ops_ = false ; + gemm_relative_tol_ = debug_options.xla_gpu_autotune_gemm_rtol(); + + // Don't run autotuning concurrently on the same GPU. + absl::MutexLock gpu_lock(&GetGpuMutex(stream_->parent())); + + TF_ASSIGN_OR_RETURN(rz_buffers_, RedzoneBuffers::FromShapes( + std::move(input_shapes), output_shape, autotune_config_, stream_.get(), + debug_options, RedzoneBuffers::kAllInputsAllOutputs)); + + return TuneGpuBlasLt(output_shape, gemm_config); } - se::DeviceMemoryBase reference_buffer; - if (autotune_config.should_check_correctness()) { - TF_ASSIGN_OR_RETURN( - reference_buffer, - allocator.AllocateBytes(ShapeUtil::ByteSizeOf(output_shape))); - } + StatusOr operator()(const HloInstruction* gemm, + const GemmConfig& gemm_config) { - BufferComparator comparator(output_shape, hlo_module_config); + VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); - std::vector results; - std::optional reference_algorithm; - - for (const AlgoT& algorithm : algorithms) { - // Make sure the output buffer always has the same value if we use - // the bias parameter. - if (autotune_config.should_reinit_output_buffer() && beta != 0) { - int64_t rng_state = 0; - InitializeBuffer(stream, output_shape.element_type(), &rng_state, - output_buffer); + if(!stream_) { + stream_ = std::make_unique< se::Stream >(autotune_config_.GetExecutor()); + stream_->Init(); } - TF_ASSIGN_OR_RETURN(se::blas::ProfileResult profile_result, - run_benchmark(algorithm)); - - results.emplace_back(); - AutotuneResult& result = results.back(); - result.mutable_gemm()->set_algorithm(profile_result.algorithm()); + const DebugOptions& debug_options = + gemm->GetModule()->config().debug_options(); + deterministic_ops_ = false ; + gemm_relative_tol_ = debug_options.xla_gpu_autotune_gemm_rtol(); - if (!profile_result.is_valid()) { // Unsupported algorithm. - result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED); - continue; - } + // Don't run autotuning concurrently on the same GPU. + absl::MutexLock gpu_lock(&GetGpuMutex(stream_->parent())); + TF_ASSIGN_OR_RETURN(rz_buffers_, RedzoneBuffers::FromInstruction( + *gemm, autotune_config_, stream_.get(), debug_options, + RedzoneBuffers::kAllInputsAllOutputs)); - VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took " - << profile_result.elapsed_time_in_ms() << "ms"; + return IsCublasLtMatmul(*gemm) + ? TuneGpuBlasLt(gemm->shape(), gemm_config) + : TuneGpuBlas(gemm->shape(), gemm_config); + } - *result.mutable_run_time() = tsl::proto_utils::ToDurationProto( - absl::Milliseconds(profile_result.elapsed_time_in_ms())); + private: + se::DeviceMemoryBase LhsBuffer() { return rz_buffers_.input_buffers().at(0); } + se::DeviceMemoryBase RhsBuffer() { return rz_buffers_.input_buffers().at(1); } + se::DeviceMemoryBase OutputBuffer() { + return rz_buffers_.output_buffers().at(0); + } - if (!autotune_config.should_check_correctness()) { - continue; + StatusOr TuneGpuBlasLt(const Shape& out_shape, const GemmConfig& gemm_config) { + + se::DeviceMemoryBase workspace_buffer; + if(out_shape.IsTuple()) { + workspace_buffer = rz_buffers_.output_buffers(). + at(out_shape.tuple_shapes_size() - 1); } - TF_ASSIGN_OR_RETURN( - se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, - allocator.CheckRedzones()); - - if (!rz_check_status.ok()) { - result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); - *result.mutable_failure()->mutable_msg() = - rz_check_status.RedzoneFailureMsg(); - LOG(ERROR) << "Detected out-of-bounds write in gemm buffer"; - CHECK(!autotune_config.should_crash_on_check_failure); - continue; - } + bool has_matrix_bias = gemm_config.beta != 0.; + bool has_vector_bias = ((int)gemm_config.epilogue & (int)BlasLt::Epilogue::kBias) != 0; + bool has_aux_output = (gemm_config.epilogue == BlasLt::Epilogue::kGELUWithAux || + gemm_config.epilogue == BlasLt::Epilogue::kBiasThenGELUWithAux); - if (!reference_algorithm) { - stream->ThenMemcpy(&reference_buffer, output_buffer, - output_buffer.size()); - reference_algorithm = profile_result.algorithm(); - } else { - // Perform the comparison. - TF_ASSIGN_OR_RETURN( - bool outputs_match, - comparator.CompareEqual(stream, output_buffer, reference_buffer)); - if (!outputs_match) { - LOG(ERROR) << "Results mismatch between different GEMM algorithms. " - << "This is likely a bug/unexpected loss of precision."; - CHECK(!autotune_config.should_crash_on_check_failure); + se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer, + d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer; - result.mutable_failure()->set_kind(AutotuneResult::WRONG_RESULT); - result.mutable_failure()->mutable_reference_gemm()->set_algorithm( - *reference_algorithm); - } + if (has_vector_bias) { + bias_buffer = rz_buffers_.input_buffers().at(has_matrix_bias ? 3 : 2); } - } - - if (!autotune_config.should_crash_on_check_failure) { - tensorflow::AutotuningLog log; - for (const AutotuneResult& result : results) { - *log.add_results() = result; + if (has_aux_output) { + aux_buffer = rz_buffers_.output_buffers().at(1); } - tsl::Logger::GetSingleton()->LogProto(log); - } + + TF_ASSIGN_OR_RETURN(auto plan, + BlasLt::GetMatmulPlan(stream_.get(), gemm_config)); - StatusOr best = - PickBestResult(results, gemm_str, hlo_module_config); - if (best.ok()) { - for (size_t i = 0; i < results.size(); ++i) { - if (best->gemm().algorithm() == results[i].gemm().algorithm()) { - return {i}; - } - } - return InternalError("unknown best algorithm"); + TF_ASSIGN_OR_RETURN( + auto algorithms, + plan->GetAlgorithms(/*max_algorithm_count*/ se::gpu::BlasLt::kMaxAlgorithms, + /*max_workspace_size*/ workspace_buffer.size())); + + auto tuned_func = [&](const BlasLt::MatmulAlgorithm& algorithm) + -> StatusOr { + TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithm)); + // Run a warmup iteration without the profiler active. + TF_RETURN_IF_ERROR(plan->ExecuteOnStream( + stream_.get(), LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(), + bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, + c_scale_buffer, d_scale_buffer, d_amax_buffer, + workspace_buffer)); + + se::blas::ProfileResult profile_result; + TF_RETURN_IF_ERROR(plan->ExecuteOnStream( + stream_.get(), LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(), + bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, + c_scale_buffer, d_scale_buffer, d_amax_buffer, + workspace_buffer, absl::nullopt, &profile_result)); + return std::move(profile_result); + }; + + const auto& shape = out_shape.IsTuple() ? out_shape.tuple_shapes(0) + : out_shape; + return GetBestAlgorithm( + shape, algorithms, gemm_config.beta, false, tuned_func); } - LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance " - "might be suboptimal: " - << best.status(); - return {std::nullopt}; -} - -StatusOr> GetBestBlasAlgorithm( - se::Stream* stream, se::RedzoneAllocator& allocator, - std::optional gemm_str, - const AutotuneConfig& autotune_config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - absl::Span algorithms, - const Shape& output_shape, const HloModuleConfig& hlo_module_config, - double beta, - const std::function( - const se::blas::AlgorithmType&)>& run_benchmark) { - TF_ASSIGN_OR_RETURN( - std::optional result, - GetBestAlgorithm( - stream, allocator, gemm_str, autotune_config, lhs_buffer, rhs_buffer, - output_buffer, algorithms, output_shape, hlo_module_config, beta, - run_benchmark)); - return result; -} - -namespace { + StatusOr TuneGpuBlas(const Shape& out_shape, + const GemmConfig& gemm_config) { +#if 0 + auto workspace_buffer = rz_buffers_.output_buffers().at(1); -StatusOr AsBlasLtEpilogue( - GemmBackendConfig_Epilogue epilogue) { - switch (epilogue) { - case GemmBackendConfig::DEFAULT: - return se::cuda::BlasLt::Epilogue::kDefault; - case GemmBackendConfig::RELU: - return se::cuda::BlasLt::Epilogue::kReLU; - case GemmBackendConfig::GELU: - return se::cuda::BlasLt::Epilogue::kGELU; - case GemmBackendConfig::GELU_AUX: - return se::cuda::BlasLt::Epilogue::kGELUWithAux; - case GemmBackendConfig::BIAS: - return se::cuda::BlasLt::Epilogue::kBias; - case GemmBackendConfig::BIAS_RELU: - return se::cuda::BlasLt::Epilogue::kBiasThenReLU; - case GemmBackendConfig::BIAS_GELU: - return se::cuda::BlasLt::Epilogue::kBiasThenGELU; - case GemmBackendConfig::BIAS_GELU_AUX: - return se::cuda::BlasLt::Epilogue::kBiasThenGELUWithAux; - default: - return InternalError("Unsupported Epilogue."); - } -} + std::vector algorithms; + TF_ASSIGN_OR_RETURN(GemmConfig::DescriptorsTuple desc, + gemm_config.GetMatrixDescriptors( + LhsBuffer(), RhsBuffer(), OutputBuffer())); -StatusOr CreateBuffer(se::RedzoneAllocator& allocator, - const Shape& shape, - const AutotuneConfig& config, - int64_t& rng_state) { - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, - allocator.AllocateBytes(ShapeUtil::ByteSizeOf(shape))); - if (config.should_init_buffers()) { - InitializeBuffer(allocator.stream(), shape.element_type(), &rng_state, - buffer); + auto blas = stream_->parent()->AsBlas(); + if (blas == nullptr) { + return xla::InternalError("No BLAS support for stream"); + } + blas->GetBlasGemmAlgorithms(stream_.get(), desc.lhs, desc.rhs, &desc.output, + &gemm_config.alpha, &gemm_config.beta, + &algorithms); + + auto tuned_func = [&](const se::blas::AlgorithmType& algorithm) + -> StatusOr { + // Do a warm-up run first, without a profile result. RunGemm swallows + // error codes when profile_result is passed, as it is in the measurement + // below, but not otherwise. It is, therefore, consistent to ignore the + // error code here. + static_cast(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(), + OutputBuffer(), workspace_buffer, + deterministic_ops_, stream_.get(), algorithm)); + se::blas::ProfileResult profile_result; + // Allow GpuTimer to use its delay kernel implementation to improve + // accuracy. + profile_result.set_warmup_run_executed(true); + // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail + // for all algorithms if we're targeting < sm_50. But because we pass a + // non-null ProfileResult, DoGemmWithAlgorithm should always return true, + // and the actual success-ness is returned in ProfileResult::is_valid. + TF_RETURN_IF_ERROR(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(), + OutputBuffer(), workspace_buffer, + deterministic_ops_, stream_.get(), algorithm, + &profile_result)); + return std::move(profile_result); + }; + + const auto& shape = out_shape.IsTuple() ? out_shape.tuple_shapes(0) + : out_shape; + return GetBestAlgorithm( + shape, algorithms, gemm_config.beta, false, tuned_func); +#else + return tensorflow::AutotuneResult{}; +#endif } - return buffer; -} -StatusOr CreateBuffer(se::RedzoneAllocator& allocator, - const HloInstruction& op, - const AutotuneConfig& config, - int64_t& rng_state) { - return CreateBuffer(allocator, op.shape(), config, rng_state); -} + // Returns the index (into `algorithms`) of the fastest algorithm. + template + StatusOr GetBestAlgorithm( + const Shape& output_shape, absl::Span algorithms, + double beta, bool return_algo_index, TunedFunc&& run_benchmark) { -static absl::Mutex autotune_cache_mu(absl::kConstInit); -static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) = - *new absl::flat_hash_map>(); -static int64_t autotune_cache_hits ABSL_GUARDED_BY(autotune_cache_mu) = 0; -static int64_t autotune_cache_misses ABSL_GUARDED_BY(autotune_cache_mu) = 0; - -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) - -StatusOr> DoGemmAutotune( - const HloInstruction* gemm, const GemmBackendConfig& gemm_config, - se::DeviceMemoryAllocator* allocator, se::Stream* stream) { - VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); - - auto key = AutotuneCacheKeyFromInstruction( - gemm, stream->parent()->GetDeviceDescription().model_str()); - - { - absl::MutexLock lock(&autotune_cache_mu); - auto it = autotune_cache.find(key); - int64_t requests = autotune_cache_hits + autotune_cache_misses; - if (requests && requests % 10 == 0) { - VLOG(2) << "Autotuning cache hits/(hits + misses): " - << autotune_cache_hits << "/" << requests; + if (!stream_->parent()->SynchronizeAllActivity()) { + return Internal("Failed to synchronize GPU for autotuning."); } - if (it != autotune_cache.end()) { - autotune_cache_hits++; - VLOG(4) << "Autotuning cache hit, using algorithm: " - << (it->second.has_value() ? absl::StrCat(*(it->second)) - : ""); - return it->second; + se::DeviceMemoryBase reference_buffer; + if (autotune_config_.should_check_correctness()) { + TF_ASSIGN_OR_RETURN(reference_buffer, + rz_buffers_.RedzoneAllocator().AllocateBytes( + ShapeUtil::ByteSizeOf(output_shape))); } - VLOG(4) << "Autotuning cache miss"; - autotune_cache_misses++; - } - - const DebugOptions& debug_options = - gemm->GetModule()->config().debug_options(); - AutotuneConfig autotune_config = GetConfig(debug_options); - TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm)); - // Don't run autotuning concurrently on the same GPU. - absl::MutexLock gpu_lock(&GetGpuMutex(stream->parent())); - - se::RedzoneAllocator buffer_allocator = - CreateRedzoneAllocator(stream, allocator, debug_options, autotune_config); - - int64_t rng_state = 0; - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer, - CreateBuffer(buffer_allocator, *gemm->operand(0), - autotune_config, rng_state)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer, - CreateBuffer(buffer_allocator, *gemm->operand(1), - autotune_config, rng_state)); - - const Shape& output_shape = - gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0) : gemm->shape(); - - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase output_buffer, - CreateBuffer(buffer_allocator, output_shape, autotune_config, rng_state)); + // Do not print error messages if should_skip_wrong_results() is ON. + BufferComparator comparator(output_shape, gemm_relative_tol_, + /* verbose */!autotune_config_.should_skip_wrong_results() + ); + std::vector results; + results.reserve(algorithms.size()); + absl::optional reference_algorithm; + + for (size_t i = 0; i < algorithms.size(); i++) { + const AlgoT& algorithm = algorithms[i]; + // Make sure the output buffer always has the same value if we use + // the bias parameter. + if (autotune_config_.should_reinit_output_buffer() && beta != 0) { + int64_t rng_state = 0; + InitializeBuffer(stream_.get(), output_shape.element_type(), &rng_state, + OutputBuffer()); + } + TF_ASSIGN_OR_RETURN(auto profile_result, run_benchmark(algorithm)); - HloModuleConfig& hlo_module_config = gemm->GetModule()->config(); - std::optional best_algorithm; - if (IsCublasLtMatmul(*gemm)) { - bool has_matrix_bias = config.beta != 0.; + results.emplace_back(); + tensorflow::AutotuneResult& result = results.back(); + result.mutable_gemm()->set_algorithm(profile_result.algorithm()); - TF_ASSIGN_OR_RETURN(bool has_vector_bias, cublas_lt::EpilogueAddsVectorBias( - gemm_config.epilogue())); + if (!profile_result.is_valid()) { // Unsupported algorithm. + result.mutable_failure()->set_kind(tensorflow::AutotuneResult::DISQUALIFIED); + continue; + } - TF_ASSIGN_OR_RETURN( - bool has_aux_output, - cublas_lt::EpilogueHasAuxiliaryOutput(gemm_config.epilogue())); + VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took " + << profile_result.elapsed_time_in_ms() << "ms"; - TF_ASSIGN_OR_RETURN(auto epilogue, - AsBlasLtEpilogue(gemm_config.epilogue())); + *result.mutable_run_time() = tsl::proto_utils::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); - se::DeviceMemoryBase bias_buffer; - if (has_vector_bias) { - TF_ASSIGN_OR_RETURN(bias_buffer, - CreateBuffer(buffer_allocator, - *gemm->operand(has_matrix_bias ? 3 : 2), - autotune_config, rng_state)); - } - se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer, - d_scale_buffer, d_amax_buffer; + if (!autotune_config_.should_check_correctness()) { + continue; + } + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, + rz_buffers_.RedzoneAllocator().CheckRedzones()); + + if (!rz_check_status.ok()) { + result.mutable_failure()->set_kind(tensorflow::AutotuneResult::REDZONE_MODIFIED); + *result.mutable_failure()->mutable_msg() = + rz_check_status.RedzoneFailureMsg(); + LOG(ERROR) << "Detected out-of-bounds write in gemm buffer"; + CHECK(!autotune_config_.should_crash_on_check_failure()); + continue; + } - se::DeviceMemoryBase aux_buffer; - if (has_aux_output) { + if (!reference_algorithm) { + stream_->ThenMemcpy(&reference_buffer, OutputBuffer(), + OutputBuffer().size()); + reference_algorithm = profile_result.algorithm(); + continue; + } + // Perform the comparison versus the reference algorithm. TF_ASSIGN_OR_RETURN( - aux_buffer, - CreateBuffer(buffer_allocator, gemm->shape().tuple_shapes(1), - autotune_config, rng_state)); + bool outputs_match, + comparator.CompareEqual(stream_.get(), /*current=*/OutputBuffer(), + /*expected=*/reference_buffer)); + if (!outputs_match) { + LOG(ERROR) << "Results mismatch between different GEMM algorithms. " + << "This is likely a bug/unexpected loss of precision."; + CHECK(!autotune_config_.should_crash_on_check_failure()); + + // By default, autotuner does NOT really skip wrong results, but + // merely prints out the above error message: this may lead to a + // great confusion. When should_skip_wrong_results() is set to true, + // solutions with accuracy problems will be disqualified. + auto kind = tensorflow::AutotuneResult::WRONG_RESULT; + if (autotune_config_.should_skip_wrong_results()) { + kind = tensorflow::AutotuneResult::DISQUALIFIED; + } + result.mutable_failure()->set_kind(kind); + result.mutable_failure()->mutable_reference_gemm()->set_algorithm( + *reference_algorithm); + } + } // for algorithms + + StatusOr best_res = + PickBestResult(results, absl::nullopt); + if (best_res.ok()) { + auto best = std::move(best_res.value()); + // Return a real algorithm ID if return_algo_index is false: + // e.g., in case of legacy cublas tuning. + if (!return_algo_index) return best; + // Otherwise, map a real algorithm ID to its index among the results. + for (size_t i = 0; i < results.size(); ++i) { + if (best.gemm().algorithm() == results[i].gemm().algorithm()) { + best.mutable_gemm()->set_algorithm(i); + return best; + } + } + return Internal("unknown best algorithm"); } - - TF_ASSIGN_OR_RETURN(auto plan, - cublas_lt::MatmulPlan::From(config, epilogue)); - TF_ASSIGN_OR_RETURN( - std::vector algorithms, - plan.GetAlgorithms(stream)); - - TF_ASSIGN_OR_RETURN( - std::optional best_algorithm_idx, - GetBestAlgorithm( - stream, buffer_allocator, gemm->ToString(), autotune_config, - lhs_buffer, rhs_buffer, output_buffer, algorithms, output_shape, - hlo_module_config, gemm_config.beta(), - [&](const se::cuda::BlasLt::MatmulAlgorithm& algorithm) - -> StatusOr { - se::OwningScratchAllocator<> scratch_allocator( - stream->parent()->device_ordinal(), allocator); - se::blas::ProfileResult profile_result; - TF_RETURN_IF_ERROR(plan.ExecuteOnStream( - stream, lhs_buffer, rhs_buffer, output_buffer, output_buffer, - bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, - c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm, - scratch_allocator, &profile_result)); - return std::move(profile_result); - })); - - TF_RET_CHECK(best_algorithm_idx) << "failed to auto-tune cublas_lt matmul"; - best_algorithm = *best_algorithm_idx; - } else { - std::vector algorithms; - TF_RET_CHECK(stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms)); - - TF_ASSIGN_OR_RETURN(std::optional best_algorithm_idx, - GetBestBlasAlgorithm( - stream, buffer_allocator, gemm->ToString(), - autotune_config, lhs_buffer, rhs_buffer, - output_buffer, algorithms, output_shape, - hlo_module_config, gemm_config.beta(), - [&](const se::blas::AlgorithmType& algorithm) - -> StatusOr { - se::blas::ProfileResult profile_result; - // We expect GemmWithAlgorithm to fail sometimes - // -- in fact, it will fail for all algorithms if - // we're targeting < sm_50. But because we pass a - // non-null ProfileResult, DoGemmWithAlgorithm - // should always return true, and the actual - // success-ness is returned in - // ProfileResult::is_valid. - TF_RETURN_IF_ERROR(RunGemm( - config, lhs_buffer, rhs_buffer, output_buffer, - stream, algorithm, &profile_result)); - return std::move(profile_result); - })); - - if (best_algorithm_idx) best_algorithm = algorithms[*best_algorithm_idx]; - } - - // Insert our result into the cache. After we released the lock on - // autotune_cache_mu, another autotuning job may have run for this same key on - // another GPU on the machine. If so, use its result. - absl::MutexLock lock(&autotune_cache_mu); - auto [it, inserted] = autotune_cache.emplace(key, best_algorithm); - return it->second; -} - -StatusOr RunOnInstruction(HloInstruction* instr, DeviceConfig config) { - se::StreamExecutor* executor = config.stream_exec; - se::DeviceMemoryAllocator* allocator = config.allocator; - if (allocator == nullptr) { - allocator = executor->GetAllocator(); - } - TF_ASSIGN_OR_RETURN(se::Stream* const stream, - allocator->GetStream(executor->device_ordinal())); - - GemmBackendConfig gemm_config = - instr->backend_config().value(); - - TF_ASSIGN_OR_RETURN(std::optional gemm_algorithm, - DoGemmAutotune(instr, gemm_config, allocator, stream)); - - // We update instruction->backend_config(); if no algorithms are supported, - // a different API is used, which does not require specifying an algorithm. - GemmBackendConfig updated_config = gemm_config; - - // We only set the 'algorithm' field on non-Ampere architectures, as for - // Ampere it's ignored in any case. - if (gemm_algorithm && - !executor->GetDeviceDescription().cuda_compute_capability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - VLOG(4) << "GEMM autotuning picked algorithm " << *gemm_algorithm << " for " - << instr->name(); - updated_config.set_selected_algorithm(*gemm_algorithm); - } - TF_RETURN_IF_ERROR(instr->set_backend_config(updated_config)); - return updated_config.SerializeAsString() != gemm_config.SerializeAsString(); -} - -#endif + LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance " + "might be suboptimal: " + << best_res.status(); + return tensorflow::AutotuneResult{}; + } // GetBestAlgorithm +}; // GemmAutotuner // Do Gemm Autotune without stream executor. Use results from autotune cache // only. -StatusOr RunOnInstruction(HloInstruction* gemm, DevicelessConfig config) { +StatusOr RunOnInstruction(HloInstruction* gemm, + const AutotuneConfig& config) { VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString(); + TF_ASSIGN_OR_RETURN(auto backend_config, + gemm->backend_config()); - auto key = AutotuneCacheKeyFromInstruction(gemm, config.model_str); - - // Load selected algorithm from the autotune cache. - std::optional algorithm; - { - absl::MutexLock lock(&autotune_cache_mu); - if (auto it = autotune_cache.find(key); it != autotune_cache.end()) { - VLOG(4) << "AOT autotuning cache hit, using algorithm: " - << (it->second.has_value() ? absl::StrCat(*(it->second)) - : ""); - algorithm = it->second; - } - VLOG(4) << "AOT autotuning cache miss"; + // Degenerate gemms replaced with memzero operation, no need to auto tune it. + if (backend_config.alpha_real() == 0.0 && + backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) { + VLOG(3) << "Skip degenerate gemm instruction auto tuning"; + return false; } - se::CudaComputeCapability capability = config.cuda_compute_capability; - GemmBackendConfig gemm_config = - gemm->backend_config().value(); - GemmBackendConfig updated_config = gemm_config; - - // We only set the 'algorithm' field on non-Ampere architectures, as for - // Ampere it's ignored in any case. - if (!capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) { - if (algorithm) { - updated_config.set_selected_algorithm(*algorithm); - } else { - updated_config.set_selected_algorithm(se::blas::kRuntimeAutotuning); - } + TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm)); + + GemmAutotuner autotuner(config); + TF_ASSIGN_OR_RETURN(auto new_algorithm, + AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, true), config, + [&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto algo, autotuner(gemm, gemm_config)); + return algo.has_gemm() ? algo.gemm().algorithm() : se::blas::kDefaultAlgorithm; + })); + + auto old_algorithm = backend_config.selected_algorithm(); + if (new_algorithm == old_algorithm) { + // We don't need to update the backend config if + // the algorithm hasn't changed unless previously + // the algorithm wasn't set explicitly. + return false; } - TF_RETURN_IF_ERROR(gemm->set_backend_config(updated_config)); - return updated_config.SerializeAsString() != gemm_config.SerializeAsString(); + + backend_config.set_selected_algorithm(new_algorithm); + TF_RETURN_IF_ERROR(gemm->set_backend_config(backend_config)); + return true; // We changed `gemm` } StatusOr RunOnComputation(HloComputation* computation, - AutotuningConfig config) { + AutotuneConfig config) { bool changed = false; + for (HloInstruction* instr : computation->instructions()) { - if (IsCublasGemm(*instr)) { - bool result; - if (auto device_config = std::get_if(&config)) { -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) - TF_ASSIGN_OR_RETURN(result, RunOnInstruction(instr, *device_config)); -#else - LOG(FATAL) << "GPU-enabled build is required to run autotuning"; -#endif - } else { - TF_ASSIGN_OR_RETURN( - result, - RunOnInstruction(instr, std::get(config))); - } + //if (IsCublasGemm(*instr)) { + if (IsCublasLtMatmul(*instr)) { // NOTE: legacy cublas autotuning is NYI ! + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, config)); changed |= result; } } @@ -509,51 +400,24 @@ StatusOr RunOnComputation(HloComputation* computation, } // namespace -void GemmAlgorithmPicker::ClearAutotuneResults() { - absl::MutexLock lock(&autotune_cache_mu); - autotune_cache.clear(); -} - -Status GemmAlgorithmPicker::WriteAutotuneResults(AutotuneResults* results) { - absl::MutexLock lock(&autotune_cache_mu); +StatusOr GemmAlgorithmPicker::RunStandalone( + const se::gpu::GemmConfig& cfg, + std::vector< Shape >&& input_shapes, const Shape& output_shape, + const DebugOptions& debug_options) { - for (const auto& [k, result] : autotune_cache) { - // For now, we don't cache "failed to autotune" results, because we don't - // have a good way to represent them in the proto. - if (!result.has_value()) continue; + GemmAutotuner autotuner(config_); + GemmConfig gemm_config{cfg}; - const auto& [model_str, hlo] = k; - auto& entry = *results->add_dots(); - entry.set_device(model_str); - entry.set_hlo(hlo); - entry.mutable_result()->mutable_gemm()->set_algorithm(*result); - } - - // Sort the results so they're deterministic. - std::sort(results->mutable_dots()->pointer_begin(), - results->mutable_dots()->pointer_end(), - [](const auto* a, const auto* b) { - return std::make_pair(absl::string_view(a->device()), - absl::string_view(a->hlo())) < - std::make_pair(absl::string_view(b->device()), - absl::string_view(b->hlo())); - }); - return OkStatus(); -} - -Status GemmAlgorithmPicker::LoadAutotuneResults( - const AutotuneResults& results) { - absl::MutexLock lock(&autotune_cache_mu); - for (const auto& result : results.dots()) { - autotune_cache[std::make_tuple(result.device(), result.hlo())] = - result.result().gemm().algorithm(); - } - return OkStatus(); + return AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, true), config_, + [&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto algo, autotuner(gemm_config, std::move(input_shapes), + output_shape, debug_options)); + return algo.has_gemm() ? algo.gemm().algorithm() : se::blas::kDefaultAlgorithm; + }); } -StatusOr GemmAlgorithmPicker::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { +StatusOr GemmAlgorithmPicker::Run(HloModule* module, + const absl::flat_hash_set& threads) { XLA_SCOPED_LOGGING_TIMER( absl::StrCat("GemmAlgorithmPicker for ", module->name())); @@ -564,7 +428,7 @@ StatusOr GemmAlgorithmPicker::Run( bool changed = false; for (HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { + module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_)); changed |= result; } @@ -573,3 +437,4 @@ StatusOr GemmAlgorithmPicker::Run( } // namespace gpu } // namespace xla + diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h index e87de621b08185..ac2beceaf8d625 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h @@ -12,67 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ +#ifndef XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ +#define XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ #include #include -#include #include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/autotune_results.pb.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" +//#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" + +#include "tensorflow/tsl/protobuf/autotuning.pb.h" + +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_serializable_autotuner.h" + +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/compiler/xla/stream_executor/device_description.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/stream_executor/blas.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" -#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" -#include "tensorflow/tsl/protobuf/autotuning.pb.h" - -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" #include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" -#endif +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" -namespace xla { +namespace stream_executor { namespace gpu { + struct GemmConfig; +} -struct AutotuneConfig { - bool should_init_buffers() const { return autotune_level >= 2; } - bool should_reinit_output_buffer() const { return autotune_level >= 3; } - bool should_check_correctness() const { return autotune_level >= 4; } - - int32_t autotune_level; - bool should_crash_on_check_failure; -}; - -static AutotuneConfig GetConfig(const DebugOptions& debug_options) { - return {debug_options.xla_gpu_autotune_level(), - debug_options.xla_gpu_crash_on_verification_failures()}; } -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -se::RedzoneAllocator CreateRedzoneAllocator( - se::Stream* stream, se::DeviceMemoryAllocator* allocator, - const DebugOptions& debug_options, const AutotuneConfig& config); -#endif - -// Select the best algorithm using information from a Blas instruction. -// Returns the index (into `algorithms`) of the fastest algorithm. -StatusOr> GetBestBlasAlgorithm( - se::Stream* stream, se::RedzoneAllocator& allocator, - std::optional gemm_str, - const AutotuneConfig& autotune_config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - absl::Span algorithms, - const Shape& output_shape, const HloModuleConfig& hlo_module_config, - double beta, - const std::function( - const se::blas::AlgorithmType&)>& run_benchmark); +namespace xla { +namespace gpu { + // GemmAlgorithmPicker supports two modes: device and deviceless. // In device mode, we run autotuning on the device and store autotune results. @@ -81,24 +55,28 @@ StatusOr> GetBestBlasAlgorithm( // autotune result is not stored, then algorithm is set to kRuntimeAutotuning. class GemmAlgorithmPicker : public HloModulePass { public: - static void ClearAutotuneResults(); - static Status WriteAutotuneResults(AutotuneResults* results); - static Status LoadAutotuneResults(const AutotuneResults& results); - - explicit GemmAlgorithmPicker(AutotuningConfig config) : config_(config) {} + explicit GemmAlgorithmPicker(AutotuneConfig config): config_(config) {} absl::string_view name() const override { return "gemm-algorithm-picker"; } + const AutotuneConfig& config() const { + return config_; + } + using HloPassInterface::Run; - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; + StatusOr Run(HloModule* module, + const absl::flat_hash_set& threads) override; + + StatusOr RunStandalone( + const se::gpu::GemmConfig& gemm_config, + std::vector< Shape >&& input_shapes, const Shape& output_shape, + const DebugOptions& debug_options); private: - AutotuningConfig config_; + AutotuneConfig config_; }; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ +#endif // XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index cac47a341c87ba..94b4585c4212b4 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/tsl/platform/statusor.h" #include "tensorflow/tsl/protobuf/dnn.pb.h" + namespace xla { namespace gpu { namespace { @@ -59,18 +60,18 @@ namespace m = match; Status SetName(HloModule *module, HloInstruction *gemm) { if (IsCublasLtMatmul(*gemm)) { module->SetAndUniquifyInstrName(gemm, "cublas-lt-matmul"); - return OkStatus(); + return absl::OkStatus(); } - GemmBackendConfig config; - TF_ASSIGN_OR_RETURN(config, gemm->backend_config()); + TF_ASSIGN_OR_RETURN(auto config, + gemm->backend_config()); const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers(); bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() || !dot_dims.rhs_batch_dimensions().empty(); module->SetAndUniquifyInstrName( gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm"); - return OkStatus(); + return absl::OkStatus(); } // Returns whether a given PrimitiveType is supported by cuBLASLt Epilogue @@ -97,103 +98,77 @@ bool IsF8Type(const HloInstruction *instr) { return primitive_util::IsF8Type(instr->shape().element_type()); } -// Recursively collects unary, pad, divide or multiply operands of instr until -// an instruction with FP8 element type is reached. Returns std::nullopt when no -// FP8 instruction is reached. -std::optional> FindF8SubgraphRecursive( - HloInstruction *instr, absl::flat_hash_set &visited_instrs, - std::vector subgraph) { - // Avoid visiting the same instruction more than once. - if (!visited_instrs.emplace(instr->unique_id()).second) { - return std::nullopt; - } - subgraph.emplace_back(instr); - if (IsF8Type(instr)) { - return subgraph; - } else { - if (instr->operand_count() == 1 || instr->opcode() == HloOpcode::kDivide || - instr->opcode() == HloOpcode::kPad) { - return FindF8SubgraphRecursive(instr->mutable_operand(0), visited_instrs, - subgraph); - } else if (instr->opcode() == HloOpcode::kMultiply) { - for (int k = 0; k < 2; ++k) { - auto mult_subgraph = FindF8SubgraphRecursive(instr->mutable_operand(k), - visited_instrs, subgraph); - if (mult_subgraph.has_value()) { - return mult_subgraph; - } - } +// Returns a new shape with non-batch dimensions padded to multiples of 16, as +// required by cuBLASLt FP8 gemms. +Shape PadShapeToMultipleOf16(const Shape old_shape, + const absl::Span batch_dims) { + Shape padded_shape = old_shape; + for (int i = 0; i < old_shape.rank(); ++i) { + if (!absl::c_linear_search(batch_dims, i)) { + int64_t padded_dimension = + RoundUpTo(old_shape.dimensions(i), 16); + padded_shape.set_dimensions(i, padded_dimension); } - return std::nullopt; } + return padded_shape; } -// Returns whether instr and its operands describe a pattern which is compatible -// with rewriting the dot operating on instr into an FP8 Custom Call. If -// applicable, captures the operand of the Custom Call, its scaling factor, -// whether the scaling factor is applied by multiplication and intermediate -// unary ops. -bool IsSupportedF8Pattern(HloInstruction *instr, HloInstruction *&x, - HloInstruction *&x_scale, bool &x_mult_scale, - std::vector &x_unary_ops) { - absl::flat_hash_set visited_instrs; - std::optional> subgraph = - FindF8SubgraphRecursive(instr, visited_instrs, - std::vector{}); - - if (!subgraph.has_value()) { - return false; +// Pad the dimensions of the operands to the target shape. +HloInstruction *PadOperandToTargetShape(const Shape &target, + HloInstruction *x) { + if (ShapeUtil::Equal(target, x->shape()) || + !ShapeUtil::SameElementType(x->shape(), target)) { + return x; } - std::reverse(subgraph->begin(), subgraph->end()); - // Directly operating on an FP8 operand. - if (subgraph->size() == 1) { - x = (*subgraph)[0]; - return true; + PaddingConfig padding_config; + for (int i = 0; i < x->shape().rank(); ++i) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(target.dimensions(i) - + x->shape().dimensions(i)); + dimension->set_interior_padding(0); } - // When not operating directly on an FP8 operand, the second and - // third instructions in the subgraph must describe a dequantization, i.e. a - // convert instruction followed by a multiply/divide instruction. - if (subgraph->size() > 2 && - Match((*subgraph)[2], - m::MultiplyAnyOrder(m::Convert(m::Op(&x)), - m::Broadcast(m::Op(&x_scale))))) { - x_mult_scale = true; - } else if (subgraph->size() > 2 && - Match((*subgraph)[2], m::Divide(m::Convert(m::Op(&x)), - m::Broadcast(m::Op(&x_scale))))) { - x_mult_scale = false; - } else { - VLOG(1) << "Possible intended FP8 GEMM operating on " - << instr->ToShortString() << " not rewritten into FP8 Custom Call."; - return false; - } + HloInstruction *zero = x->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(x->shape().element_type()))); + return x->AddInstruction( + HloInstruction::CreatePad(target, x, zero, padding_config)); +} - auto preserves_element_type = [](const HloInstruction *instr) -> bool { - return ShapeUtil::SameElementType(instr->shape(), - instr->operand(0)->shape()); - }; - for (int i = 3; i < subgraph->size(); ++i) { - // The remaining instructions must be commutative with dequantization. - // Bitcast, broadcast, copy, pad, reshape and slice instructions are - // supported. - if (!Match((*subgraph)[i], - m::AnyOf( - m::Bitcast().WithPredicate(preserves_element_type), - m::Broadcast(), m::Copy(), m::Pad(), m::Reshape(), - m::Slice()))) { - VLOG(1) << "Possible intended FP8 GEMM operating on " - << instr->ToShortString() - << " not rewritten into FP8 Custom Call."; - return false; - } +// Pad the non-batch dimensions of the operands to multiples of 16 as required +// by cuBLASLt FP8 gemms. +HloInstruction *PadOperandToMultipleOf16(absl::Span batch_dims, + HloInstruction *x) { + Shape padded_shape = PadShapeToMultipleOf16(x->shape(), batch_dims); + return PadOperandToTargetShape(padded_shape, x); +} + +// Calculates the reciprocal of scalar when invert is true and converts to FP32. +StatusOr InvertAndConvertScalar(HloInstruction *scalar, + bool invert) { + DCHECK(ShapeUtil::IsScalar(scalar->shape())); + + if (invert) { + Literal one_literal = LiteralUtil::One(scalar->shape().element_type()); + HloInstruction *one = scalar->parent()->AddInstruction( + HloInstruction::CreateConstant(one_literal.Clone())); + TF_ASSIGN_OR_RETURN(scalar, MakeBinaryHlo(HloOpcode::kDivide, one, scalar, + &scalar->metadata())); + } + if (scalar->shape().element_type() != F32) { + scalar = MakeConvertToHlo(scalar, F32, &scalar->metadata()); } - x_unary_ops = {subgraph->begin() + 3, subgraph->end()}; - return true; + return scalar; } +// A path of instructions by traversing downwards through users, as (op, +// operand_index) pairs. operand_index is the index to get to the previous +// element in the path. I.e., +// path[i].first->operand(path[i].second) == path[i-1].first +using InstrPath = std::vector>; + // Transposes a matrix by swapping the contracting and non-contracting // dimension. There must be only one contracting and only one non-contracting // dimension. Keeps the layout the same. @@ -222,7 +197,6 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, HloInstruction::CreateTranspose(new_shape, instr, permutation)); } - // If the bias is a sequence of ops that depend only on broadcasts of // constants, materialize the bias if it's small. // @@ -286,16 +260,6 @@ auto GemmOrCublasLtMatmul(HloInstruction **instr) { return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget}); } -auto CublasLtMatmulMaybeF8(HloInstruction **instr) { - return m::CustomCall( - instr, {kCublasLtMatmulCallTarget, kCublasLtMatmulF8CallTarget}); -} - -auto GemmOrCublasLtMatmulMaybeF8(HloInstruction **instr) { - return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget, - kCublasLtMatmulF8CallTarget}); -} - auto BcastConstScalar(HloInstruction **instr, double value) { return m::Broadcast(instr, m::ConstantScalar(value)); } @@ -308,11 +272,27 @@ auto BcastConstScalarNear(double value) { // Not a very robust floating-point comparison, but good enough for our // purposes. std::optional actual = - static_cast(instr) + xla::Cast(instr) ->literal() .GetAsDouble({}); if (!actual.has_value()) return false; - double epsilon = 128 * std::numeric_limits::epsilon(); + double epsilon; + switch (instr->shape().element_type()) { + case F16: + epsilon = 128 * std::numeric_limits::epsilon(); + break; + case BF16: + epsilon = 128 * std::numeric_limits::epsilon(); + break; + case F32: + epsilon = 128 * std::numeric_limits::epsilon(); + break; + case F64: + epsilon = 128 * std::numeric_limits::epsilon(); + break; + default: + return false; + } return abs(*actual - expected) < (abs(*actual + expected) * epsilon); })); } @@ -329,6 +309,12 @@ auto OptionalConvert(HloInstruction **optional_convert, Pattern pattern) { std::move(pattern)); } +template +auto OptionalBitcast(HloInstruction **optional_bitcast, Pattern pattern) { + return m::AnyOf(m::Bitcast(optional_bitcast, pattern), + std::move(pattern)); +} + // The rewriting proceeds in a bottom-up way: // // (kDot A B) is rewritten into a (kCustomCall:gemm A B) @@ -363,19 +349,36 @@ auto OptionalConvert(HloInstruction **optional_convert, Pattern pattern) { // when the output of the GEMM is requested in FP8 format. class GemmRewriterVisitor : public DfsHloRewriteVisitor { public: - explicit GemmRewriterVisitor( - GpuVersion gpu_version) + explicit GemmRewriterVisitor(const GpuVersion &gpu_version) : gpu_version_(gpu_version) {} Status HandleDot(HloInstruction *instr) override { - if (!IsMatrixMultiplication(*instr)) { - return OkStatus(); + if (!IsMatrixMultiplication(*instr) && + !IsMatrixVectorMultiplication(*instr)) { + return absl::OkStatus(); + } + // Sparse dot is not supported. + // if (Cast(instr)->sparse_operands()) { + // return absl::OkStatus(); + // } + + int64_t gemm_rewrite_size_threshold = + instr->GetModule() + ->config() + .debug_options() + .xla_gpu_gemm_rewrite_size_threshold(); + TF_ASSIGN_OR_RETURN(bool is_matmul_tiny, + IsMatrixMultiplicationTooSmallForRewriting( + *instr, gemm_rewrite_size_threshold)); + if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*instr)) { + return absl::OkStatus(); } CHECK(!instr->IsRank2Transpose()); - CHECK(!instr->mutable_operand(0)->IsRank2Transpose()); - CHECK(!instr->mutable_operand(1)->IsRank2Transpose()); - + if (instr->operand(0)->IsRank2Transpose() || + instr->operand(1)->IsRank2Transpose()) { + return absl::OkStatus(); + } // Create a GemmBackendConfig based on the instruction. GemmBackendConfig gemm_backend_config; gemm_backend_config.set_alpha_real(1.0); @@ -392,78 +395,52 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { gemm_backend_config.set_grad_y(attributes["grad_y"] == "true"); int64_t lhs_batch_dims_size = - instr->dot_dimension_numbers().lhs_batch_dimensions_size(); - int64_t lhs_stride = lhs->shape().dimensions(lhs_batch_dims_size) * - lhs->shape().dimensions(lhs_batch_dims_size + 1); - int64_t rhs_stride = rhs->shape().dimensions(lhs_batch_dims_size) * - rhs->shape().dimensions(lhs_batch_dims_size + 1); + instr->dot_dimension_numbers().lhs_batch_dimensions_size(); + bool is_lhs_vector = + lhs->shape().dimensions_size() == lhs_batch_dims_size + 1; + bool is_rhs_vector = + rhs->shape().dimensions_size() == lhs_batch_dims_size + 1; + int64_t lhs_stride = + is_lhs_vector ? lhs->shape().dimensions(lhs_batch_dims_size) + : lhs->shape().dimensions(lhs_batch_dims_size) * + lhs->shape().dimensions(lhs_batch_dims_size + 1); + int64_t rhs_stride = + is_rhs_vector ? rhs->shape().dimensions(lhs_batch_dims_size) + : rhs->shape().dimensions(lhs_batch_dims_size) * + rhs->shape().dimensions(lhs_batch_dims_size + 1); gemm_backend_config.set_lhs_stride(lhs_stride); gemm_backend_config.set_rhs_stride(rhs_stride); - // First try to match the fp8 gemm pattern. - TF_ASSIGN_OR_RETURN(bool supported_by_cublaslt, - GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); - HloInstruction *a, *b, *a_scale = nullptr, *b_scale = nullptr; - std::vector a_unary_ops, b_unary_ops; - bool a_mult_scale, b_mult_scale; - if (supported_by_cublaslt && - Match(instr, - m::Dot(m::Op().WithPredicate([&](const HloInstruction *instr) { - return IsSupportedF8Pattern(const_cast(instr), - a, a_scale, a_mult_scale, - a_unary_ops); - }), - m::Op().WithPredicate([&](const HloInstruction *instr) { - return IsSupportedF8Pattern( - const_cast(instr), b, b_scale, - b_mult_scale, b_unary_ops); - })))) { + { + // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call. TF_ASSIGN_OR_RETURN( - bool created_call, - CreateF8CustomCall(instr, gemm_backend_config, a, b, a_scale, b_scale, - a_mult_scale, b_mult_scale, a_unary_ops, - b_unary_ops)); - if (created_call) { - return OkStatus(); - } + absl::string_view gemm_custom_call_target, + GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); + const Shape &output_shape = instr->shape(); + HloInstruction *gemm_call = + instr->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, + {instr->mutable_operand(0), instr->mutable_operand(1)}, + gemm_custom_call_target)); + TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gemm_backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); } - - if (IsF8Type(instr->operand(0))) { - // Couldn't rewrite as an FP8 cublasLt custom call, so turn into an FP16 - // dot and below it will be rewritten as an FP16 cublas or cublasLt call. - TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr)); - } - - // Couldn't rewrite as an FP8 cublasLt custom call, rewrite as a cublas or - // cublasLt call. - TF_ASSIGN_OR_RETURN( - absl::string_view gemm_custom_call_target, - GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); - const Shape &output_shape = instr->shape(); - HloInstruction *gemm_call = - instr->AddInstruction(HloInstruction::CreateCustomCall( - output_shape, - {instr->mutable_operand(0), instr->mutable_operand(1)}, - gemm_custom_call_target)); - TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gemm_backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); - return OkStatus(); + return absl::OkStatus(); } Status HandleMultiply(HloInstruction *instr) override { HloInstruction *alpha, *existing_gemm; if (Match(instr, m::MultiplyAnyOrder( - GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(), + GemmOrCublasLtMatmul(&existing_gemm).WithOneUser(), m::Broadcast(m::ConstantScalar(&alpha)).WithOneUser()))) { TF_ASSIGN_OR_RETURN(auto config, existing_gemm->backend_config()); - // Do not fuse alpha into S32 GEMM, as they only support fixed values for // alpha/beta. if (existing_gemm->shape().element_type() == S32) { - return OkStatus(); + return absl::OkStatus(); } if (config.beta() == 0.0 && existing_gemm->user_count() == 1) { @@ -481,70 +458,94 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // (https://arxiv.org/abs/1606.08415), where: // approx_gelu(x) = x * cdf(x) // cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)) - HloInstruction *cdf; - if (Match(instr, m::MultiplyAnyOrder(CublasLtMatmul(&existing_gemm), - m::Op(&cdf).WithOneUser())) && + HloInstruction *cdf, *slice_or_bitcast = nullptr; + if (Match(instr, m::MultiplyAnyOrder( + m::AnyOf( + m::Slice(&slice_or_bitcast, + CublasLtMatmul(&existing_gemm)), + m::Bitcast(&slice_or_bitcast, + CublasLtMatmul(&existing_gemm)), + CublasLtMatmul(&existing_gemm)), + m::Op(&cdf).WithOneUser())) && Match(cdf, m::MultiplyAnyOrder( BcastConstScalar(0.5), m::AddAnyOrder( BcastConstScalar(1.0), - m::Tanh(m::MultiplyAnyOrder( - BcastConstScalarNear(sqrt(M_2_PI)), - m::AddAnyOrder( - m::Op().Is(existing_gemm), + m::Tanh( + m::MultiplyAnyOrder( + BcastConstScalarNear(sqrt(M_2_PI)), + m::AddAnyOrder( + m::Op().Is(slice_or_bitcast ? slice_or_bitcast + : existing_gemm), + m::MultiplyAnyOrder( + BcastConstScalarNear(0.044715), m::MultiplyAnyOrder( - BcastConstScalarNear(0.044715), + m::Op().Is(slice_or_bitcast + ? slice_or_bitcast + : existing_gemm), m::MultiplyAnyOrder( - m::Op().Is(existing_gemm), - m::MultiplyAnyOrder( - m::Op().Is(existing_gemm), - m::Op().Is(existing_gemm)) - .WithOneUser()) + m::Op().Is(slice_or_bitcast + ? slice_or_bitcast + : existing_gemm), + m::Op().Is(slice_or_bitcast + ? slice_or_bitcast + : existing_gemm)) .WithOneUser()) .WithOneUser()) .WithOneUser()) .WithOneUser()) + .WithOneUser()) .WithOneUser())))) { - return FuseGeluActivation(instr, existing_gemm); + return FuseGeluActivation(instr, existing_gemm, slice_or_bitcast); } - return OkStatus(); + return absl::OkStatus(); + } + + // Fuse the scaling of an FP8 GEMM into the Custom Call. + Status HandleDivide(HloInstruction *instr) override { + return absl::OkStatus(); } Status HandleAdd(HloInstruction *instr) override { - HloInstruction *bias, *existing_gemm; + HloInstruction *bias, *existing_gemm = nullptr; HloInstruction *optional_slice = nullptr; HloInstruction *optional_convert = nullptr; - // Attempt to elide broadcast and fuse addition of a vector bias into GEMM, - // including when slicing is applied to the result. + HloInstruction *optional_bitcast = nullptr; + // Attempt to elide broadcast and fuse addition of a vector bias into + // GEMM, including when slicing is applied to the result. if (Match(instr, m::AddAnyOrder( - OptionalSlice( - &optional_slice, - CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()) + OptionalBitcast( + &optional_bitcast, + OptionalSlice( + &optional_slice, + CublasLtMatmul(&existing_gemm).WithOneUser()) + .WithOneUser()) .WithOneUser(), m::Broadcast(&bias, OptionalConvert(&optional_convert, m::Op()))))) { - TF_ASSIGN_OR_RETURN(bool was_fused, - FuseVectorBiasAdd(instr, bias, existing_gemm, - optional_slice, optional_convert)); + TF_ASSIGN_OR_RETURN( + bool was_fused, + FuseVectorBiasAdd(instr, bias, existing_gemm, optional_slice, + optional_convert, optional_bitcast)); if (was_fused) { - return OkStatus(); + return absl::OkStatus(); } } - // Attempt to elide broadcast and fuse addition of a vector bias into // *batched* GEMM as a matrix bias addition using FuseMatrixBiasAdd. // add(bitcast(gemm(a, b)), broadcast(bias)) -> // bitcast(add(gemm(a, b), bitcast(broadcast(bias)))) -> // bitcast(gemm(a, b, bitcast(broadcast(bias)))) (FuseMatrixBiasAdd) // - if (Match(instr, - m::AddAnyOrder( - m::Bitcast(CublasLtMatmul(&existing_gemm).WithOneUser()) - .WithOneUser(), - m::Broadcast(&bias, m::Op()).WithOneUser()))) { + if (Match( + instr, + m::AddAnyOrder( + m::Bitcast(CublasLtMatmul(&existing_gemm).WithOneUser()) + .WithOneUser(), + m::Broadcast(&bias, m::Op()).WithOneUser()))) { TF_ASSIGN_OR_RETURN( HloInstruction * new_add, MakeBinaryHlo(HloOpcode::kAdd, existing_gemm, @@ -575,7 +576,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // transformation, but it doesn't hurt anything. if (Match(instr, m::AddAnyOrder( - m::Bitcast(GemmOrCublasLtMatmul(&existing_gemm).WithOneUser()) + m::Bitcast( + GemmOrCublasLtMatmul(&existing_gemm).WithOneUser()) .WithOneUser(), m::Op(&bias).WithPredicate(is_not_broadcast)))) { HloInstruction *new_bitcast = @@ -590,13 +592,65 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr = new_add; } + // Attempt to fuse matrix bias into gemm with optional convert + // add(convert(gemm(a, b)), c) -> gemm(a, b, c) + // add(gemm(a, b), c) -> gemm(a, b, c) if (Match(instr, - m::AddAnyOrder(GemmOrCublasLtMatmul(&existing_gemm).WithOneUser(), - m::Op(&bias).WithPredicate(is_not_broadcast)))) { - return FuseMatrixBiasAdd(instr, bias, existing_gemm); + m::AddAnyOrder( + m::AnyOf( + GemmOrCublasLtMatmul(&existing_gemm).WithOneUser(), + m::Convert( + GemmOrCublasLtMatmul(&existing_gemm).WithOneUser()) + .WithOneUser()), + m::Op(&bias).WithPredicate(is_not_broadcast)))) { + TF_ASSIGN_OR_RETURN(auto gemm_backend_config, + existing_gemm->backend_config()); + // check if type combination is supported here + TF_ASSIGN_OR_RETURN( + bool types_are_supported, + IsLegacyCublasMatmul(*existing_gemm) + ? TypesAreSupportedByLegacyCublas(*existing_gemm, + gemm_backend_config, instr) + : TypesAreSupportedByCublasLt(*existing_gemm, gemm_backend_config, + instr)); + + // for mix type gemm, only fuse add if there is no consumers + // ROOT add + // ROOT tuple(add) + bool has_no_consumer = + instr->shape().element_type() == + existing_gemm->shape().element_type() || + instr->user_count() == 0 || + (instr->user_count() == 1 && + instr->users()[0]->opcode() == HloOpcode::kTuple && + instr->users()[0]->user_count() == 0); + + if (types_are_supported && has_no_consumer) { + return FuseMatrixBiasAdd(instr, bias, existing_gemm); + } + } + + HloInstruction *optional_bitcast_matrix = nullptr; + HloInstruction *optional_slice_matrix = nullptr; + if (Match(instr, + m::AddAnyOrder( + OptionalBitcast( + &optional_bitcast_matrix, + OptionalSlice(&optional_slice_matrix, + GemmOrCublasLtMatmul(&existing_gemm) + .WithOneUser())) + .WithOneUser(), + m::Op(&bias).WithPredicate(is_not_broadcast)))) { + // The matrix bias must not be FP8, see + // https://docs.nvidia.com/cuda/cublas/index.html. + if (!IsF8Type(bias)) { + return FuseMatrixBiasAdd(instr, bias, existing_gemm, + optional_bitcast_matrix, + optional_slice_matrix); + } } - return OkStatus(); + return absl::OkStatus(); } Status HandleMaximum(HloInstruction *instr) override { @@ -609,477 +663,79 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { m::AnyOf( m::Slice( &optional_slice_or_bitcast, - CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()), + CublasLtMatmul(&existing_gemm).WithOneUser()), m::Bitcast( &optional_slice_or_bitcast, - CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()), - CublasLtMatmulMaybeF8(&existing_gemm)) + CublasLtMatmul(&existing_gemm).WithOneUser()), + CublasLtMatmul(&existing_gemm)) .WithOneUser(), m::Broadcast(&zeros, m::ConstantScalar(0))))) { TF_RETURN_IF_ERROR(FuseReluActivation(instr, zeros, existing_gemm, optional_slice_or_bitcast)); } - return OkStatus(); + return absl::OkStatus(); } Status HandleConvert(HloInstruction *instr) override { - HloInstruction *clamp_lower, *clamp_upper, *d_scale, *existing_gemm, - *binary; - - // Attempt to elide the scaling and conversion of the result of an FP8 - // GEMM, including the optional calculation of the maximum of the absolute - // values before scaling, and adapt the Custom Call. - if (Match(instr, - m::Convert( - m::Clamp(m::Broadcast(m::ConstantScalar(&clamp_lower)), - m::AnyOf( - m::Divide( - &binary, - m::CustomCall(&existing_gemm, - {kCublasLtMatmulF8CallTarget}), - m::Broadcast(m::Op(&d_scale))), - m::MultiplyAnyOrder( - &binary, - m::CustomCall(&existing_gemm, - {kCublasLtMatmulF8CallTarget}), - m::Broadcast(m::Op(&d_scale)))), - m::Broadcast(m::ConstantScalar(&clamp_upper))) - .WithOneUser()))) { - return F8ConvertD( - instr, existing_gemm, d_scale, clamp_lower, clamp_upper, - /*mult_scale=*/binary->opcode() == HloOpcode::kMultiply); - } - return OkStatus(); + return absl::OkStatus(); } - StatusOr CreateF8CustomCall(HloInstruction *instr, - GemmBackendConfig &gemm_backend_config, - HloInstruction *a, HloInstruction *b, - HloInstruction *a_scale, - HloInstruction *b_scale, bool a_mult_scale, - bool b_mult_scale, - std::vector a_unary_ops, - std::vector b_unary_ops) { -#if GOOGLE_CUDA - auto cuda_compute_capability = - std::get(gpu_version_); - - // FP8 GEMM kernels are only available on Hopper and newer architectures. - if (!cuda_compute_capability.IsAtLeast( - se::CudaComputeCapability::HOPPER)) { - VLOG(1) << "FP8 Custom Calls require Hopper or newer architecture."; - return false; - } -#if CUDA_VERSION < 11080 - // FP8 GEMM kernels are only available with CUDA 11.8 and above - VLOG(1) << "FP8 Custom Calls require CUDA 11.8 or newer."; - return false; -#endif // CUDA_VERSION - - // cuBLASLt FP8 GEMM kernels require one of the two operands to be in - // F8E4M3FN format. - if (a->shape().element_type() == F8E5M2 && - b->shape().element_type() == F8E5M2) { - VLOG(1) - << "Failed to rewrite " << instr->ToShortString() - << " into FP8 Custom Call. The element type of one of the operands " - "must be F8E4M3FN."; - return false; - } - - absl::Span batch_dims = - gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions(); - - // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32 - // format. Set the factors to one when no scaling factors were captured. - Literal one_literal = LiteralUtil::One(F32); - HloInstruction *one = instr->AddInstruction( - HloInstruction::CreateConstant(one_literal.Clone())); - std::array mult_scale{a_mult_scale, b_mult_scale}; - std::array scales{a_scale, b_scale}, inv_scales, - scales_f32; - for (int i = 0; i < scales.size(); ++i) { - if (scales[i]) { - if (!ShapeUtil::IsScalar(scales[i]->shape())) { - VLOG(1) << "Failed to rewrite " << instr->ToShortString() - << " into FP8 Custom Call. The scaling factors must be " - "scalars."; - return false; - } - if (!mult_scale[i]) { - inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary( - scales[i]->shape(), HloOpcode::kDivide, one, scales[i])); - } - scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i]; - if (scales_f32[i]->shape().element_type() != F32) { - scales_f32[i] = instr->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::MakeScalarShape(F32), scales_f32[i])); - } - } else { - scales_f32[i] = one; - } - } - - PrimitiveType c_type; - switch (instr->shape().element_type()) { - case F8E4M3FN: - case F8E5M2: - case BF16: - c_type = BF16; - break; - case F16: - c_type = F16; - break; - case F32: - c_type = F32; - break; - default: - VLOG(1) << "Failed to rewrite " << instr->ToShortString() - << " into FP8 Custom Call. Output element type must be " - "F8E4M3FN, F8E5M2, BF16, F16 or F32. Actual element type is " - << PrimitiveType_Name(instr->shape().element_type()); - return false; - } - - // Fuse the possible addition of a matrix bias here to enable the subsequent - // fusion of the scaling and conversion of D into the Custom Call. Fusing - // a matrix bias is only supported with CUDA 12 and above. - HloInstruction *c = nullptr; -#if CUDA_VERSION > 12000 - if (instr->user_count() == 1 && - instr->users()[0]->opcode() == HloOpcode::kAdd) { - HloInstruction *add = instr->users()[0]; - HloInstruction *bias = add->mutable_operand(!add->operand_index(instr)); - if (bias->opcode() != HloOpcode::kBroadcast) { - c = bias; - gemm_backend_config.set_beta(1.0); - TF_RETURN_IF_ERROR(ReplaceInstruction(add, instr)); - } - } -#endif // CUDA_VERSION > 12000 - // If a matrix bias was not fused, set C to a matrix of zeros. - if (!c) { - Literal c_literal = LiteralUtil::Zero(c_type); - HloInstruction *c_const = instr->AddInstruction( - HloInstruction::CreateConstant(c_literal.Clone())); - c = instr->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::ChangeElementType(instr->shape(), c_type), c_const, {})); - } - - // Each operand must have exactly one contracting and one non-contracting - // dimension. - absl::Span a_contracting_dims = - gemm_backend_config.dot_dimension_numbers() - .lhs_contracting_dimensions(); - absl::Span b_contracting_dims = - gemm_backend_config.dot_dimension_numbers() - .rhs_contracting_dimensions(); - if (a_contracting_dims.size() != 1 || b_contracting_dims.size() != 1) { - VLOG(1) << "Failed to rewrite " << instr->ToShortString() - << " into FP8 Custom Call. A and B must have one contracting " - "dimension."; - return false; - } - if ((a_unary_ops.empty() ? a : a_unary_ops.back()) - ->shape() - .dimensions_size() - - batch_dims.size() != - 2 || - (b_unary_ops.empty() ? b : b_unary_ops.back()) - ->shape() - .dimensions_size() - - batch_dims.size() != - 2) { - VLOG(1) << "Failed to rewrite " << instr->ToShortString() - << "into FP8 Custom Call. A and B must have one non-contracting " - "dimension."; - return false; - } - - // Sequentially apply the collected unary and pad ops to the unconverted and - // unscaled operands. - auto shift_unary_ops = - [&instr](HloInstruction *&x, - std::vector &x_unary_ops) -> void { - for (HloInstruction *unary_op : x_unary_ops) { - std::vector operands = {x}; - if (unary_op->opcode() == HloOpcode::kPad) { - HloInstruction *convert = - instr->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(unary_op->operand(1)->shape(), - x->shape().element_type()), - unary_op->mutable_operand(1))); - operands.emplace_back(convert); - } - x = instr->AddInstruction(unary_op->CloneWithNewOperands( - ShapeUtil::MakeShapeWithDenseLayout( - x->shape().element_type(), unary_op->shape().dimensions(), - unary_op->shape().layout().minor_to_major()), - operands)); - } - return; - }; - shift_unary_ops(a, a_unary_ops); - shift_unary_ops(b, b_unary_ops); - - TF_ASSIGN_OR_RETURN(bool a_is_col_major, - MatrixIsColumnMajor(*instr, gemm_backend_config, "a")); - TF_ASSIGN_OR_RETURN(bool b_is_col_major, - MatrixIsColumnMajor(*instr, gemm_backend_config, "b")); - - DotDimensionNumbers *dim_nums = - gemm_backend_config.mutable_dot_dimension_numbers(); - int batch_dim_offset = batch_dims.size(); - - // cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A, to - // be row-major. If A is column-major, swap the contracting and - // non-contracting dimension and transpose the matrix to effectively make it - // column-major. - // TODO(philipphack): Remove once cuBLASLt supports A being column-major - if (a_is_col_major) { - CHECK(a_contracting_dims[0] == batch_dim_offset || - a_contracting_dims[0] == batch_dim_offset + 1); - if (a_contracting_dims[0] == batch_dim_offset) { - dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset + 1); - } else { - dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset); - } - a = TransposeMatrix(a, a_contracting_dims[0], batch_dims); - } - - // Similarly, cuBLASLt requires the second operand to be column-major, so - // make it column-major if it is currently row-major. - if (!b_is_col_major) { - CHECK(b_contracting_dims[0] == batch_dim_offset || - b_contracting_dims[0] == batch_dim_offset + 1); - if (b_contracting_dims[0] == batch_dim_offset) { - dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset + 1); - } else { - dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset); - } - b = TransposeMatrix(b, b_contracting_dims[0], batch_dims); - } - - // Pad the non-batch dimensions of the operands to multiples of 16 as - // required by cuBLASLt. - auto pad_operand = [&instr, &batch_dims](HloInstruction *&x) -> void { - PaddingConfig padding_config; - Shape padded_shape = x->shape(); - for (int i = 0; i < x->shape().rank(); ++i) { - auto dimension = padding_config.add_dimensions(); - if (!absl::c_linear_search(batch_dims, i)) { - int64_t padded_dimension = - RoundUpTo(x->shape().dimensions(i), 16); - dimension->set_edge_padding_low(0); - dimension->set_edge_padding_high(padded_dimension - - x->shape().dimensions(i)); - dimension->set_interior_padding(0); - padded_shape.set_dimensions(i, padded_dimension); - } - } - if (!ShapeUtil::Equal(padded_shape, x->shape())) { - HloInstruction *zero = - instr->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(x->shape().element_type()))); - x = instr->AddInstruction( - HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); - } - return; - }; - pad_operand(a); - pad_operand(b); - pad_operand(c); - - HloInstruction *new_custom_call = - instr->AddInstruction(HloInstruction::CreateCustomCall( - ShapeUtil::MakeShapeWithDenseLayout( - instr->shape().element_type(), c->shape().dimensions(), - instr->shape().layout().minor_to_major()), - {a, b, c, scales_f32[0], scales_f32[1], one, one}, - kCublasLtMatmulF8CallTarget)); - - TF_RETURN_IF_ERROR( - new_custom_call->set_backend_config(gemm_backend_config)); - TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call)); - - // Slice the result of the GEMM if the operands were padded. - HloInstruction *slice = nullptr; - if (c->shape().dimensions() != instr->shape().dimensions()) { - std::vector start_indices(instr->shape().rank(), 0); - std::vector strides(instr->shape().rank(), 1); - slice = instr->AddInstruction(HloInstruction::CreateSlice( - instr->shape(), new_custom_call, start_indices, - instr->shape().dimensions(), strides)); - } - TF_RETURN_IF_ERROR( - ReplaceInstruction(instr, slice ? slice : new_custom_call)); - return true; -#else // TENSORFLOW_USE_ROCM - return false; -#endif + static bool IsCuda(const GpuVersion &gpu_version) { + return std::holds_alternative(gpu_version); } - Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm, - HloInstruction *d_scale, HloInstruction *clamp_lower, - HloInstruction *clamp_upper, bool mult_scale = false) { - // Verify the data types and the operands of clamp. - if (instr->shape().element_type() == F8E4M3FN) { - if (!clamp_lower->literal().IsAllFloat(static_cast( - std::numeric_limits::lowest())) || - !clamp_upper->literal().IsAllFloat(static_cast( - std::numeric_limits::max()))) { - return OkStatus(); - } - } else if (instr->shape().element_type() == F8E5M2) { - if (!clamp_lower->literal().IsAllFloat(static_cast( - std::numeric_limits::lowest())) || - !clamp_upper->literal().IsAllFloat(static_cast( - std::numeric_limits::max()))) { - return OkStatus(); - } - } else { - return OkStatus(); - } - - if (!ShapeUtil::IsScalar(d_scale->shape())) { - return OkStatus(); - } - - // The possible second user of the GEMM must be the calculation of the - // maximum of the absolute value of the result of the GEMM. Since it is - // unknown in what form this operation will be used, it is identified in a - // top-down approach by inspecting the users of the GEMM. - const std::vector gemm_users = existing_gemm->users(); - HloInstruction *reduce_damax = nullptr; - if (gemm_users.size() == 2) { - // In the presence of a ReLU activation, the abs instruction is elided - // since abs(ReLU(x)) = ReLU(x). - TF_ASSIGN_OR_RETURN(auto config, - existing_gemm->backend_config()); - for (int i = 0; i < gemm_users.size(); ++i) { - HloInstruction *maybe_reduce = nullptr; - if (gemm_users[i]->opcode() == HloOpcode::kAbs) { - if (gemm_users[i]->users().size() != 1) continue; - maybe_reduce = gemm_users[i]->users()[0]; - } else { - // If there is no Abs instruction, relu is required as epilogue to - // ensure all values are nonnegative. - if (config.epilogue() != GemmBackendConfig::BIAS_RELU && - config.epilogue() != GemmBackendConfig::RELU) - continue; - maybe_reduce = gemm_users[i]; - } - - if (maybe_reduce->opcode() == HloOpcode::kReduce && - maybe_reduce->operands().size() == 2 && - maybe_reduce->operand(1)->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalar(maybe_reduce->operand(1)->shape())) { - HloInstruction *reduce = maybe_reduce; - HloComputation *reduce_comp = reduce->to_apply(); - HloInstruction *reduce_comp_root = reduce_comp->root_instruction(); - if (reduce->operand(1)->literal().GetAsDouble({}) <= 0. && - reduce_comp_root->opcode() == HloOpcode::kMaximum && - reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter && - reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter) { - reduce_damax = reduce; - } - } - } - if (!reduce_damax) { - return OkStatus(); - } - } else if (gemm_users.size() > 2) { - return OkStatus(); + static StatusOr GetCudaComputeCapability( + const GpuVersion &gpu_version) { + auto *cuda_cc = std::get_if(&gpu_version); + if (cuda_cc == nullptr) { + return absl::InvalidArgumentError("Compute Capability is not CUDA."); } + return *cuda_cc; + } - // Change the data type of C to BF16 as required by cuBLASLt for GEMMs with - // FP8 outputs (see cuBLASLt documentation). - if (existing_gemm->operand(2)->shape().element_type() != BF16 && - existing_gemm->operand(2)->shape().element_type() != F16) { - TF_ASSIGN_OR_RETURN(auto gemm_backend_config, - existing_gemm->backend_config()); - if (gemm_backend_config.beta() == 1.0) { - VLOG(1) << "The scaling and conversion of the result of " - << existing_gemm->ToShortString() - << " is not fused into the FP8 Custom Call because it " - "conflicts with the existing fusion of the addition of a " - "matrix bias with element type other than BF16 or F16."; - return OkStatus(); - } else { - Literal c_literal = LiteralUtil::Zero(BF16); - HloInstruction *c = instr->AddInstruction( - HloInstruction::CreateConstant(c_literal.Clone())); - HloInstruction *c_bcast = - instr->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::ChangeElementType(instr->shape(), BF16), c, {})); - TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(2, c_bcast)); - } - } - - // If necessary, invert the scaling factor of D and convert to F32. - if (!mult_scale) { - Literal one_literal = LiteralUtil::One(d_scale->shape().element_type()); - HloInstruction *one = instr->AddInstruction( - HloInstruction::CreateConstant(one_literal.Clone())); - d_scale = instr->AddInstruction(HloInstruction::CreateBinary( - d_scale->shape(), HloOpcode::kDivide, one, d_scale)); - } - if (d_scale->shape().element_type() != F32) { - d_scale = instr->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::MakeScalarShape(F32), d_scale)); - } - TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(6, d_scale)); - - // If present, elide the calculation of the maximum of the absolute values - // of the result of the GEMM. - if (reduce_damax) { - return F8AddDAmax(instr, existing_gemm, reduce_damax); - } - - std::unique_ptr new_gemm = - existing_gemm->CloneWithNewShape(instr->shape()); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(new_gemm))); - - return OkStatus(); + static bool IsRocm(const GpuVersion &gpu_version) { + return std::holds_alternative(gpu_version); } - // Adds a scalar DAmax return value to an FP8 GEMM. - Status F8AddDAmax(HloInstruction *instr, HloInstruction *existing_gemm, - HloInstruction *reduce_damax) { - // Change the output shape of the Custom Call to tuple(D, DAmax). - Shape damax_shape = ShapeUtil::MakeScalarShape(F32); - Shape tuple_shape = - ShapeUtil::MakeTupleShape({instr->shape(), damax_shape}); - HloInstruction *gemm_and_damax = - instr->AddInstruction(existing_gemm->CloneWithNewShape(tuple_shape)); - - // Obtain D and DAmax separately from the output tuple. - HloInstruction *d = - instr->AddInstruction(HloInstruction::CreateGetTupleElement( - instr->shape(), gemm_and_damax, 0)); - HloInstruction *damax = instr->AddInstruction( - HloInstruction::CreateGetTupleElement(damax_shape, gemm_and_damax, 1)); - - // Convert DAmax from FP32 to the requested type and elide reduce. - HloInstruction *damax_converted = instr->AddInstruction( - HloInstruction::CreateConvert(reduce_damax->shape(), damax)); - TF_RETURN_IF_ERROR(ReplaceInstruction(reduce_damax, damax_converted)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, d)); - - return OkStatus(); + static StatusOr GetRocmComputeCapability( + const GpuVersion &gpu_version) { + auto rocm_cc = std::get_if(&gpu_version); + if (rocm_cc == nullptr) { + return absl::InvalidArgumentError("Compute Capability is not ROCm."); + } + return *rocm_cc; } + // Fuses a matrix bias into a cuBLAS call. 'instr' should be an Add + // instruction in the following form: + // Add(OptionalBitcast(OptionalSlice(gemm)), bias) + // where 'gemm' is expected to be a cuBLAS custom_call. Slice is introduced + // when the inputs of the gemm are possibly padded. Bitcast is introduced to + // handle high rank input. Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias, - const HloInstruction *gemm, - HloInstruction *bitcast = nullptr) { - TF_RET_CHECK(bias->shape() == (bitcast ? bitcast->shape() : gemm->shape())); + const HloInstruction *gemm, + HloInstruction *bitcast = nullptr, + HloInstruction *slice = nullptr) { + TF_RET_CHECK(Shape::Equal().IgnoreElementType()(bias->shape(), + bitcast ? bitcast->shape() + : slice ? slice->shape() + : gemm->shape())); // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only // supports fixed values for alpha/beta. if (gemm->shape().element_type() == S32) { - return OkStatus(); + return absl::OkStatus(); } + // To ensure correctness, only slices that chop off the ends of dimensions + // are supported. + if (slice) { + int slice_op_dim = slice->operand(0)->shape().rank(); + if (slice->slice_starts() != std::vector(slice_op_dim, 0) || + slice->slice_strides() != std::vector(slice_op_dim, 1)) { + return absl::OkStatus(); + } + } // Cublas gemm overwrites the bias matrix, so fusion is only possible if the // gemm is the only user. CublasLt gemm can operate out-of-place. bool can_overwrite_bias = [bias]() { @@ -1111,8 +767,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { }(); bool want_to_fuse_bias = IsCublasLtMatmul(*gemm) || can_overwrite_bias; - auto config = gemm->backend_config().value(); - + TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); // It is possible to fuse into a cublasLt matmul that already has a vector // bias, but no other epilogue will commute with the matrix bias add. bool supported_epilogue = @@ -1121,18 +776,29 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if ((config.beta() != 0) || !want_to_fuse_bias || (gemm->user_count() != 1) || !supported_epilogue) { - return OkStatus(); + return absl::OkStatus(); } config.set_beta(1.0); std::vector operands(gemm->operands().begin(), gemm->operands().end()); - operands.insert(operands.begin() + 2, MaybeConstantFoldBias(bias)); + HloInstruction *maybe_constant_folded_bias = MaybeConstantFoldBias(bias); + if (bitcast) { + maybe_constant_folded_bias = + instr->AddInstruction(HloInstruction::CreateBitcast( + slice->shape(), maybe_constant_folded_bias)); + } + + maybe_constant_folded_bias = + PadOperandToTargetShape(gemm->shape(), maybe_constant_folded_bias); + + operands.insert(operands.begin() + 2, maybe_constant_folded_bias); std::unique_ptr fused_op = gemm->CloneWithNewOperands(gemm->shape(), operands); - + // set output shape to bias shape if mix type + fused_op->mutable_shape()->set_element_type(bias->shape().element_type()); TF_RETURN_IF_ERROR(fused_op->set_backend_config(config)); // Choose whether the bias must alias the output. Legacy cublas GEMMs must @@ -1156,8 +822,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { ->set_output_to_operand_aliasing({{{}, {2, {}}}}); } TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get())); + if (slice) { + fused_op = slice->CloneWithNewOperands( + slice->shape(), + {slice->parent()->AddInstruction(std::move(fused_op))}); + } - if (bitcast != nullptr) { + if (bitcast) { fused_op = bitcast->CloneWithNewOperands( bitcast->shape(), {bitcast->parent()->AddInstruction(std::move(fused_op))}); @@ -1166,14 +837,23 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return ReplaceWithNewInstruction(instr, std::move(fused_op)); } + // Fuses a vector bias into a cuBLAS call. 'instr' should be an Add + // instruction in the following form: + // Add(OptionalBitcast(OptionalSlice(gemm)), Broadcast(OptionalConvert())) + // where 'gemm' is expected to be a cuBLAS custom_call. The optional + // convert is only used for F8 matmuls as cublasLt has specific constraints + // on the vector bias type for such matmuls. The optional bitcast is + // necessary to handle high rank input cases. StatusOr FuseVectorBiasAdd(HloInstruction *instr, - HloInstruction *broadcast, - HloInstruction *gemm, - HloInstruction *slice = nullptr, - HloInstruction *convert = nullptr) { - TF_RET_CHECK(ShapeUtil::Compatible( - broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); - + HloInstruction *broadcast, + HloInstruction *gemm, + HloInstruction *slice = nullptr, + HloInstruction *convert = nullptr, + HloInstruction *bitcast = nullptr) { + if (!bitcast) { + TF_RET_CHECK(ShapeUtil::Compatible( + broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); + } // Verify that the data type is supported by Epilogue Fusion. if (!SupportsEpilogueFusion(gemm->shape().element_type())) { return false; @@ -1181,8 +861,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *bias = broadcast->mutable_operand(0); - TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); - + TF_ASSIGN_OR_RETURN(auto config, + gemm->backend_config()); // # output column dims == # non-contracting rhs operand dims. const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers(); size_t num_col_dims = gemm->operand(1)->shape().rank() - @@ -1199,7 +879,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // physical dimensions of the gemm output. absl::Span broadcast_dims = broadcast->dimensions(); for (size_t i = 0; i < num_col_dims; ++i) { - int64_t dim = gemm->shape().layout().minor_to_major(i); + int64_t dim = + (bitcast ? bitcast : gemm)->shape().layout().minor_to_major(i); // Find the corresponding dimension from the bias vector. auto it = absl::c_find(broadcast_dims, dim); @@ -1216,47 +897,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::vector operands(gemm->operands().begin(), gemm->operands().end()); - // When (non-trivial) matrix and vector bias co-exist for FP8 matmul, just - // fuse matrix bias. - if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget && - config.beta() != 0.0) { - return true; - } - - if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget && - bias->shape().element_type() == F32) { - if (convert == nullptr) { - return false; - } - - HloInstruction *bias_f16_or_bf16 = convert->mutable_operand(0); - auto compatible_bias_type = [](const PrimitiveType bias_type, - const PrimitiveType output_type) { - if (bias_type == BF16) { - return output_type == F8E4M3FN || output_type == F8E5M2 || - output_type == F32 || output_type == BF16; - } else if (bias_type == F16) { - return output_type == F16 || output_type == F8E4M3FN || - output_type == F8E5M2; - } - return false; - }; - - // cuBLAS LT does not support FP32 biases on matmuls with FP8 inputs, - // even if the matmul output is FP32. We do not unconditionally convert - // the bias to a supported precision (F16 or BF16) because this lowers - // precision. Instead, we only fuse the bias if the bias itself is a - // convert from F16 or BF16, fusing the input of the convert instruction - // to the matmul. - if (compatible_bias_type(bias_f16_or_bf16->shape().element_type(), - gemm->shape().element_type())) { - bias = bias_f16_or_bf16; - } else { - VLOG(1) << "Epilogue fusion of FP32 vector bias into FP8 GEMM is " - "currently not supported. See the cublasLT support matrix."; - return false; - } - } // Replace add(gemm, broadcast) with fused new_gemm. operands.push_back(bias); @@ -1265,44 +905,51 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { gemm->CloneWithNewOperands(gemm->shape(), operands); TF_RETURN_IF_ERROR(result->set_backend_config(config)); TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get())); - if (slice != nullptr) { + if (slice) { result = slice->CloneWithNewOperands( slice->shape(), {slice->parent()->AddInstruction(std::move(result))}); } + if (bitcast) { + result = bitcast->CloneWithNewOperands( + bitcast->shape(), + {bitcast->parent()->AddInstruction(std::move(result))}); + } TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(result))); return true; } - Status FuseReluActivation(HloInstruction *instr, HloInstruction *broadcast, - HloInstruction *gemm, - HloInstruction *slice_or_bitcast = nullptr) { + Status FuseReluActivation(HloInstruction *instr, + HloInstruction *broadcast, + HloInstruction *gemm, + HloInstruction *slice_or_bitcast = nullptr) { TF_RET_CHECK(ShapeUtil::Compatible( broadcast->shape(), (slice_or_bitcast ? slice_or_bitcast->shape() : gemm->shape()))); if (!SupportsEpilogueFusion(gemm->shape().element_type())) { - return OkStatus(); + return absl::OkStatus(); } if (gemm->user_count() != 1) { - return OkStatus(); + return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); + TF_ASSIGN_OR_RETURN(auto config, + gemm->backend_config()); if (config.epilogue() == GemmBackendConfig::DEFAULT) { config.set_epilogue(GemmBackendConfig::RELU); } else if (config.epilogue() == GemmBackendConfig::BIAS) { config.set_epilogue(GemmBackendConfig::BIAS_RELU); } else { - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr result = gemm->Clone(); TF_RETURN_IF_ERROR(result->set_backend_config(config)); TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get())); - if (slice_or_bitcast != nullptr) { + if (slice_or_bitcast) { result = slice_or_bitcast->CloneWithNewOperands( slice_or_bitcast->shape(), {slice_or_bitcast->parent()->AddInstruction(std::move(result))}); @@ -1311,15 +958,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return ReplaceWithNewInstruction(instr, std::move(result)); } - Status FuseGeluActivation(HloInstruction *multiply, HloInstruction *gemm) { + Status FuseGeluActivation(HloInstruction *multiply, + HloInstruction *gemm, + HloInstruction *slice_or_bitcast = nullptr) { if (!SupportsEpilogueFusion(gemm->shape().element_type())) { - return OkStatus(); + return absl::OkStatus(); } - // There are four users of the gemm output within the GELU calculation. bool has_aux = gemm->user_count() > 4; - TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); + TF_ASSIGN_OR_RETURN(auto config, + gemm->backend_config()); + if (config.epilogue() == GemmBackendConfig::DEFAULT) { config.set_epilogue(has_aux ? GemmBackendConfig::GELU_AUX : GemmBackendConfig::GELU); @@ -1327,7 +977,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { config.set_epilogue(has_aux ? GemmBackendConfig::BIAS_GELU_AUX : GemmBackendConfig::BIAS_GELU); } else { - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr output = gemm->CloneWithNewShape( @@ -1336,6 +986,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR(output->set_backend_config(config)); TF_RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get())); + if (slice_or_bitcast) { + output = slice_or_bitcast->CloneWithNewOperands( + slice_or_bitcast->shape(), + {gemm->parent()->AddInstruction(std::move(output))}); + } + if (has_aux) { HloInstruction *tuple_output = gemm->parent()->AddInstruction(std::move(output)); @@ -1380,78 +1036,52 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (gemm_is_supported_by_cublas_lt) { return absl::string_view(kCublasLtMatmulCallTarget); } - + return InternalError("Unsupported hipblaslt gemm config: %s", instr.ToString()); // This case is not supported by cublasLt, fallback to legacy cublas. - return absl::string_view(kGemmCallTarget); + ///return absl::string_view(kGemmCallTarget); } - StatusOr TypesAreSupportedByCublasLt( - const HloInstruction &instr) const { + StatusOr TypesAreSupportedByLegacyCublas( + const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config, + const HloInstruction *bias = nullptr) const { // Figure out the Atype/Btype. const PrimitiveType a_dtype = instr.operand(0)->shape().element_type(); const PrimitiveType b_dtype = instr.operand(1)->shape().element_type(); - // cublasLt has a defined set of combinations of types that it supports. - // Figure out the computeType and scaleType. - TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype, - AsBlasDataType(instr.shape().element_type())); - TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type, - GetBlasComputationType( - a_dtype, instr.shape().element_type(), - stream_executor::blas::kDefaultComputePrecision)); + const PrimitiveType output_type = + bias ? bias->shape().element_type() : instr.shape().element_type(); + const std::array supported_type = { + PrimitiveType::S8, PrimitiveType::F16, PrimitiveType::BF16, + PrimitiveType::F32, PrimitiveType::S32, PrimitiveType::F64, + PrimitiveType::C64, PrimitiveType::C128}; + // legacy cublas has a defined set of combinations of types that it + // supports. Figure out the computeType and scaleType. + if (!absl::c_linear_search(supported_type, output_type)) return false; + TF_ASSIGN_OR_RETURN(auto output_dtype, se::gpu::AsBlasDataType(output_type)); + TF_ASSIGN_OR_RETURN(auto blas_a_dtype, se::gpu::AsBlasDataType(a_dtype)); + // TODO(tdanyluk): Investigate why don't we use the actual precision (and + // algorithm) here? Why do we use the default? + TF_ASSIGN_OR_RETURN(auto compute_type, se::gpu::GetBlasComputationType( + blas_a_dtype, output_dtype, + se::blas::kDefaultComputePrecision)); se::blas::DataType scale_type = - cublas_lt::GetScaleType(output_dtype, compute_type); + se::gpu::GetScaleType(output_dtype, compute_type); using se::blas::ComputationType; using se::blas::DataType; - // This matrix of supported types is taken directly from cublasLt + // This matrix of supported types is taken directly from cublas // documentation. - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul - const std::array< - std::tuple, - 32> - supported_type_combinations = {{ - // FP8 types: - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, - PrimitiveType::F8E4M3FN, DataType::kBF16}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, - PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, - PrimitiveType::F8E4M3FN, DataType::kHalf}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, - PrimitiveType::F8E4M3FN, DataType::kFloat}, - - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, - PrimitiveType::F8E5M2, DataType::kBF16}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, - PrimitiveType::F8E5M2, DataType::kF8E4M3FN}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, - PrimitiveType::F8E5M2, DataType::kF8E5M2}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, - PrimitiveType::F8E5M2, DataType::kHalf}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, - PrimitiveType::F8E5M2, DataType::kFloat}, - - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, - PrimitiveType::F8E4M3FN, DataType::kBF16}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, - PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, - PrimitiveType::F8E4M3FN, DataType::kF8E5M2}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, - PrimitiveType::F8E4M3FN, DataType::kHalf}, - {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, - PrimitiveType::F8E4M3FN, DataType::kFloat}, - - // Other data types: + // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex + + using TypeCombinations = std::initializer_list>; + + const TypeCombinations supported_type_combinations = { {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16, PrimitiveType::F16, DataType::kHalf}, {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8, PrimitiveType::S8, DataType::kInt32}, - {ComputationType::kI32, DataType::kFloat, PrimitiveType::S8, - PrimitiveType::S8, DataType::kInt8}, {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16, PrimitiveType::BF16, DataType::kBF16}, @@ -1491,7 +1121,149 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { {ComputationType::kF64, DataType::kComplexDouble, PrimitiveType::C128, PrimitiveType::C128, DataType::kComplexDouble}, - }}; + }; + + return absl::c_linear_search( + supported_type_combinations, + std::make_tuple(compute_type, scale_type, a_dtype, b_dtype, + output_dtype)); + } + + StatusOr TypesAreSupportedByCublasLt( + const HloInstruction &instr, const GemmBackendConfig &backend_config, + const HloInstruction *bias = nullptr) const { + // Figure out the Atype/Btype. + const PrimitiveType a_dtype = instr.operand(0)->shape().element_type(); + const PrimitiveType b_dtype = instr.operand(1)->shape().element_type(); + const PrimitiveType output_type = + bias ? bias->shape().element_type() : instr.shape().element_type(); + const std::array supported_type = { + PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN, + PrimitiveType::S8, PrimitiveType::F16, + PrimitiveType::BF16, PrimitiveType::F32, + PrimitiveType::S32, PrimitiveType::F64, + PrimitiveType::C64, PrimitiveType::C128}; + if (!absl::c_linear_search(supported_type, output_type)) return false; + // cublasLt has a defined set of combinations of types that it supports. + // Figure out the computeType and scaleType. + TF_ASSIGN_OR_RETURN(auto output_dtype, se::gpu::AsBlasDataType(output_type)); + TF_ASSIGN_OR_RETURN(auto blas_a_dtype, se::gpu::AsBlasDataType(a_dtype)); + + const int max_precision = *absl::c_max_element( + backend_config.precision_config().operand_precision()); + + TF_ASSIGN_OR_RETURN( + auto compute_type, + se::gpu::GetBlasComputationType( + blas_a_dtype, output_dtype, max_precision)); + se::blas::DataType scale_type = + se::gpu::GetScaleType(output_dtype, compute_type); + + using se::blas::ComputationType; + using se::blas::DataType; + using TypeCombinations = std::initializer_list>; + // This matrix of supported types is taken directly from cublasLt + // documentation. + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul + const TypeCombinations supported_cublas_type_combinations = { + // FP8 types: + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E4M3FN, DataType::kBF16}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E4M3FN, DataType::kHalf}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E4M3FN, DataType::kFloat}, + + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kBF16}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kF8E4M3FN}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kF8E5M2}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kHalf}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN, + // PrimitiveType::F8E5M2, DataType::kFloat}, + + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kBF16}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kF8E5M2}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kHalf}, + // {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, + // PrimitiveType::F8E4M3FN, DataType::kFloat}, + // There would be an entry here for A/BType complex int8, but we do + // not support that type. + {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64, + PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kF16AsF32, DataType::kComplexFloat, + PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat}, + // The next 4 may be supported by hipblaslt, but they are not + // covered by any unit tests + {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kBF16AsF32, DataType::kComplexFloat, + PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kTF32AsF32, DataType::kComplexFloat, + PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64, + PrimitiveType::F64, DataType::kDouble}, + {ComputationType::kF64, DataType::kComplexDouble, PrimitiveType::C128, + PrimitiveType::C128, DataType::kComplexDouble}, + }; + if (IsCuda(gpu_version_) && + absl::c_linear_search(supported_cublas_type_combinations, + std::tuple{compute_type, scale_type, a_dtype, + b_dtype, output_dtype})) { + return true; + } + const TypeCombinations supported_type_combinations = { + // Other data types: + + {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8, + PrimitiveType::S8, DataType::kInt32}, + {ComputationType::kI32, DataType::kFloat, PrimitiveType::S8, + PrimitiveType::S8, DataType::kInt8}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8, + PrimitiveType::S8, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kBF16}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kFloat}, + {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kBF16}, + {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kHalf}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kFloat}, + {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kHalf}, + {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + }; return absl::c_linear_search( supported_type_combinations, @@ -1507,6 +1279,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const DotDimensionNumbers &dot_dims = gemm_backend_config.dot_dimension_numbers(); + // We use ALG_UNSET and kDefaultComputePrecision because we don't care about + // the precision, just the layout, since we're just checking if the matrix + // is column-major. TF_ASSIGN_OR_RETURN( GemmConfig gemm_config, GemmConfig::For( @@ -1516,8 +1291,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { dot_dims.rhs_contracting_dimensions(), /*output_shape=*/instr.shape(), gemm_backend_config.alpha_real(), gemm_backend_config.alpha_imag(), gemm_backend_config.beta(), - /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision, - gemm_backend_config.grad_x(), gemm_backend_config.grad_y())); + /*algorithm*/ se::blas::kDefaultAlgorithm, se::blas::kDefaultComputePrecision, + se::gpu::BlasLt::Epilogue::kDefault)); if (matrix_name == "lhs" || matrix_name == "a") { return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor; @@ -1527,24 +1302,20 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return gemm_config.output_layout.order == MatrixLayout::Order::kColumnMajor; } else { - return InternalError("Invalid matrix name."); + return Internal("Invalid matrix name."); } } StatusOr GemmIsSupportedByCublasLt( const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config) const { - - // hipblas-lt is available since r2.14 (r2.13 and r2.12 don't have). - if (std::holds_alternative(gpu_version_)) - return false; - const HloInstruction *lhs = instr.operand(0); const HloInstruction *rhs = instr.operand(1); const Shape &output_shape = instr.shape(); - TF_ASSIGN_OR_RETURN(bool types_are_supported_by_cublas_lt, - TypesAreSupportedByCublasLt(instr)); + TF_ASSIGN_OR_RETURN( + bool types_are_supported_by_cublas_lt, + TypesAreSupportedByCublasLt(instr, gemm_backend_config)); if (!types_are_supported_by_cublas_lt) { return false; } @@ -1566,6 +1337,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } + TF_ASSIGN_OR_RETURN(bool output_is_column_major, + MatrixIsColumnMajor(instr, gemm_backend_config)); + + // if (auto isrocm = std::get_if(&gpu_version_); + // isrocm) { + // if (!isrocm->has_hipblaslt()) { + // return false; + // } + // } + // 2. cublasLt does not support rhs col dimension size > 4194240 for // C64. constexpr int kMaxDimensionSize{4194240}; @@ -1574,14 +1355,24 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return true; } + if (std::holds_alternative(gpu_version_)) { + auto cuda_compute_capability_ = + std::get(gpu_version_); + if (cuda_compute_capability_.IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + // cuBlasLt has an implementation for complex data with compute type + // 32F_FAST_32TF that uses tensor cores and that is free from the + // restriction. This implementation only works on Ampere + // architecture though (where TF32 was introduced). + return true; + } + } // Get the rhs non-contracting dimensions as they will eventually be at the // cublasLt level. std::vector rhs_non_contracting_dims; const DotDimensionNumbers &dot_dims = gemm_backend_config.dot_dimension_numbers(); - TF_ASSIGN_OR_RETURN(bool output_is_column_major, - MatrixIsColumnMajor(instr, gemm_backend_config)); if (!output_is_column_major) { // cublasLt's matmul output is column major by default. This gemm requires // the output to be in row major. Later we will swap lhs & rhs (and @@ -1608,40 +1399,108 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return lhs_non_contracting_dimension_size <= kMaxDimensionSize; } - // Turns an F8 dot into an F16 dot, converting operands to F16 and - // converting the output back to F8. - StatusOr TurnF8DotIntoF16Dot(HloInstruction *instr) { - DCHECK(IsF8Type(instr)); - DCHECK(IsF8Type(instr->operand(0))); - DCHECK(IsF8Type(instr->operand(1))); - - // Convert operands to F16 - for (int i = 0; i < 2; ++i) { - Shape operand_f16_shape = instr->operand(i)->shape(); - operand_f16_shape.set_element_type(F16); - HloInstruction *convert = - instr->AddInstruction(HloInstruction::CreateConvert( - operand_f16_shape, instr->mutable_operand(i))); - TF_RETURN_IF_ERROR(instr->ReplaceOperandWith(i, convert)); +}; + +// Rewriter that adds a workspace to legacy cuBLAS custom calls. We run it +// separately after gemm rewriter, so that we can do pattern matching without +// having to match output tuples. +class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { + public: + explicit GemmWorkspaceRewriteVisitor( + const GpuVersion &gpu_version) + : gpu_version_(gpu_version) {} + + Status HandleCustomCall(HloInstruction *instr) override { + bool has_aux_output = false; + + // add workspace only for cublas-lt calls + if (instr->custom_call_target() != kCublasLtMatmulCallTarget) { + return absl::OkStatus(); } + TF_ASSIGN_OR_RETURN(auto config, + instr->backend_config()); + GemmBackendConfig_Epilogue epilogue = config.epilogue(); + TF_ASSIGN_OR_RETURN( + has_aux_output, + gpublas_lt::EpilogueHasAuxiliaryOutput(epilogue)); + + if (!((instr->shape().IsTuple() && + instr->shape().tuple_shapes_size() == + has_aux_output + /*config.damax_output()*/ + 1) || + instr->shape().IsArray())) { + return absl::OkStatus(); + } + + auto *cuda_cc = std::get_if(&gpu_version_); - // Clone instruction and convert output to F8 - Shape output_f16_shape = instr->shape(); - output_f16_shape.set_element_type(F16); - HloInstruction *f16_dot = - instr->AddInstruction(instr->CloneWithNewShape(output_f16_shape)); - HloInstruction *convert_to_f8 = instr->AddInstruction( - HloInstruction::CreateConvert(instr->shape(), f16_dot)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert_to_f8)); - return f16_dot; + // Pass a user-managed workspace to legacy cuBLAS operations, as + // otherwise cuBLAS will use its own internal pool which will be competing + // with XLA allocator for device memory. + int64_t workspace = cuda_cc == nullptr ? GemmConfig::kDefaultWorkspace + : cuda_cc->IsAtLeast(se::CudaComputeCapability::HOPPER) + ? GemmConfig::kHopperWorkspace + : GemmConfig::kDefaultWorkspace; + + // We do not know the workspace size required by cuBLAS, but we can guess + // that in a worst case cuBLAS will transpose all operands into tiled + // layout optimal for the tensor cores. It doesn't make sense to allocate a + // larger workspace. + // + // TODO(ezhulenev): This is not based on any measurement, just a common + // sense, we should tweak it to find the minimal workspace size. + if (instr->custom_call_target() == kGemmCallTarget) { + int64_t operands_byte_size = 0; + for (auto &operand : instr->operands()) { + operands_byte_size += ShapeUtil::ByteSizeOf(operand->shape()); + } + workspace = std::min(workspace, operands_byte_size); + } + + // Append workspace buffer to instruction outputs. + std::vector output_shapes = instr->shape().IsArray() + ? std::vector{instr->shape()} + : instr->shape().tuple_shapes(); + output_shapes.emplace_back(ShapeUtil::MakeShape(S8, {workspace})); + Shape output_shape = ShapeUtil::MakeTupleShape(output_shapes); + + // Clone custom call with a new shape. + HloInstruction *new_call = instr->AddInstruction( + instr->CloneWithNewOperands(output_shape, instr->operands())); + + // Update operand aliasing if it was a fused gemm with aliased output. + auto *custom_call = xla::Cast(new_call); + if (!custom_call->output_to_operand_aliasing().empty()) { + custom_call->set_output_to_operand_aliasing({{{0}, {2, {}}}}); + } + + if (instr->shape().IsTuple()) { + for (auto user : instr->users()) { + auto user_get_tuple = + dynamic_cast(user); + TF_RET_CHECK(user_get_tuple); + HloInstruction *get_output = + instr->AddInstruction(HloInstruction::CreateGetTupleElement( + new_call, user_get_tuple->tuple_index())); + TF_RETURN_IF_ERROR(ReplaceInstruction(user_get_tuple, get_output)); + } + return absl::OkStatus(); + } else { + HloInstruction *get_output = instr->AddInstruction( + HloInstruction::CreateGetTupleElement(new_call, 0)); + return ReplaceInstruction(instr, get_output); + } } + + private: + GpuVersion gpu_version_; }; -StatusOr RunOnComputation( - HloComputation *computation, - GpuVersion gpu_version) { +StatusOr RunOnComputation(HloComputation *computation, + const GpuVersion& gpu_version) { GemmRewriterVisitor visitor(gpu_version); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version); + TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor)); return visitor.changed(); } @@ -1656,8 +1515,8 @@ StatusOr GemmRewriter::Run( bool changed = false; for (HloComputation *computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN( - bool result, RunOnComputation(computation, gpu_version_)); + TF_ASSIGN_OR_RETURN(bool result, + RunOnComputation(computation, gpu_version_)); changed |= result; } return changed; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.h b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.h index 571742b38e7e25..c2e23c1790ffe3 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.h @@ -31,9 +31,10 @@ namespace gpu { // (kMultiply (kDot A B) alpha) // (kMultiply C beta)) // -// where A, B, C are matrixes and `alpha` and `beta` are host constants. -// The additional requirement is that C has no other users (otherwise, -// it does not make sense to fuse it inside the custom call). +// where A, B, C are matrices or vectors and `alpha` and `beta` are host +// constants. In matrix-vector multiplication, one operand must be a matrix and +// the other must be a vector. The additional requirement is that C has no other +// users (otherwise, it does not make sense to fuse it inside the custom call). // // Both multiplication and addition can be avoided (equivalent to setting // `alpha` to one and `beta` to zero). @@ -44,7 +45,8 @@ namespace gpu { // stored in the backend config. class GemmRewriter : public HloModulePass { public: - explicit GemmRewriter(GpuVersion gpu_version); + + GemmRewriter(GpuVersion gpu_version); absl::string_view name() const override { return "cublas-gemm-rewriter"; } using HloPassInterface::Run; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index b5d35bf1ca25a2..06798bd25da629 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -39,9 +39,14 @@ GemmThunk::GemmThunk(ThunkInfo thunk_info, GemmConfig config, Status GemmThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(3) << "Running GEMM thunk"; const BufferAllocations& allocs = *params.buffer_allocations; + + se::DeviceMemoryBase workspace_buffer{}; return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_), allocs.GetDeviceAddress(rhs_buffer_), - allocs.GetDeviceAddress(output_buffer_), params.stream); + allocs.GetDeviceAddress(output_buffer_), + workspace_buffer, + /* deterministic_ops */false, + params.stream); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index b2f802c37e5a44..1eb64e3726fbd2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -105,7 +105,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h" #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_serializable_autotuner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_shape_verifier.h" #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_stats.h" #include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" @@ -168,6 +167,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" +#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_platform_id.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" @@ -190,7 +190,6 @@ limitations under the License. #include "rocm/rocm_config.h" #endif #if GOOGLE_CUDA -#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/triton_autotuner.h" #endif // GOOGLE_CUDA @@ -241,6 +240,16 @@ tsl::thread::ThreadPool* GetThreadPool( return &*overriding_thread_pool; } } + +AutotuneConfig GetAutotuneConfig( + se::StreamExecutor* stream_exec, const DebugOptions& debug_options, + se::DeviceMemoryAllocator* device_allocator) { + + CHECK(stream_exec != nullptr); + return AutotuneConfig{DeviceConfig{stream_exec, device_allocator}, + debug_options}; +} + } // end anonymous namespace StatusOr> @@ -901,17 +910,13 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( .VerifyReshapeIsBitcast(), /*debug_only=*/true); - AutotuningConfig autotune_config = - stream_exec ? AutotuningConfig{DeviceConfig{stream_exec, - options.device_allocator}} - : AutotuningConfig{DevicelessConfig{ - gpu_target_config.device_description_str}}; + auto autotune_config = GetAutotuneConfig(stream_exec, debug_options, + options.device_allocator); // Linearize collective schedule under SPMD partitioning if online autotuning // of convolutions is enabled. const bool enable_collecive_schedule_linearizer_for_spmd = hlo_module->config().use_spmd_partitioning() && - autotune_config.is_online() && GpuConvAlgorithmPicker::IsEnabled(hlo_module); if (enable_collecive_schedule_linearizer_for_spmd) { @@ -919,23 +924,13 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( RequiresCollectiveScheduleLinearizer); } - if (autotune_config.is_offline()) { - GpuConvAlgorithmPicker::ClearAutotuneResults(); - TF_RETURN_IF_ERROR( - GpuConvAlgorithmPicker::LoadAutotuneResults(*autotune_results)); -#if GOOGLE_CUDA - GemmAlgorithmPicker::ClearAutotuneResults(); - TF_RETURN_IF_ERROR( - GemmAlgorithmPicker::LoadAutotuneResults(*autotune_results)); - TritonAutotuner::ClearAutotuneResults(); - TF_RETURN_IF_ERROR(TritonAutotuner::LoadAutotuneResults(*autotune_results)); -#endif // GOOGLE_CUDA - } if (GpuConvAlgorithmPicker::IsEnabled(hlo_module)) { pipeline.AddPass(autotune_config); } + + pipeline.AddPass( + autotune_config); #if GOOGLE_CUDA - pipeline.AddPass(autotune_config); const HloModuleConfig& module_config = hlo_module->config(); std::optional overriding_thread_pool; tsl::thread::ThreadPool* thread_pool = GetThreadPool( @@ -989,6 +984,10 @@ StatusOr> GpuCompiler::RunHloPasses( [&] { return absl::StrCat("HLO Transforms:", module->name()); }, tsl::profiler::TraceMeLevel::kInfo); + const DebugOptions& debug_opts = module->config().debug_options(); + auto cfg = GetAutotuneConfig(stream_exec, debug_opts, nullptr); + TF_RETURN_IF_ERROR(AutotunerUtil::LoadAutotuneResultsFromFileOnce(cfg)); + GpuTargetConfig gpu_target_config = GetGpuTargetConfig(stream_exec); TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), stream_exec, options, gpu_target_config, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 4f3b8c12a02287..4e9b4248728573 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -324,7 +324,7 @@ struct ConvCacheStats { absl::Mutex autotune_cache_mu(absl::kConstInit); auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) = - *new absl::flat_hash_map(); + *new absl::flat_hash_map(); auto& autotune_cache_stats ABSL_GUARDED_BY(autotune_cache_mu) = *new ConvCacheStats(); @@ -382,25 +382,8 @@ bool ShouldCheckConv(const HloModuleConfig& hlo_module_config) { StatusOr GpuConvAlgorithmPicker::PickBestAlgorithm( const HloCustomCallInstruction* instr) { - // If in deviceless mode, return the result from the autotune_cache. - if (auto deviceless_config = std::get_if(&config_)) { - auto device_description_str = deviceless_config->model_str; - AutotuneCacheKey key = - AutotuneCacheKeyFromInstruction(instr, device_description_str); - absl::MutexLock autotune_lock(&autotune_cache_mu); - auto it = autotune_cache.find(key); - if (it != autotune_cache.end()) { - return it->second; - } - - // Return an autotune result with algo id -1, which means that we autotune - // at runtime. - AutotuneResult result; - result.mutable_algorithm()->set_algo_id(-1); - return result; - } - - se::StreamExecutor* stream_exec = std::get(config_).stream_exec; + + auto stream_exec = config_.GetExecutor(); // Don't run this function concurrently on the same GPU. // // This is a bit of a hack and doesn't protect us against arbitrary concurrent @@ -416,7 +399,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithm( // which can greatly improve both stability (deterministic numeric results // within a process for a given input) and performance (2x speedup on some // models). - AutotuneCacheKey key = AutotuneCacheKeyFromInstruction( + AutotuneConvCacheKey key = AutotunerUtil::ConvCacheKeyFromInstruction( instr, stream_exec->GetDeviceDescription().model_str()); { absl::MutexLock autotune_lock(&autotune_cache_mu); @@ -438,8 +421,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithm( // allocator either points to this->allocator_ or, if that's null, to a // se::StreamExecutorMemoryAllocator for stream_exec. - se::DeviceMemoryAllocator* device_allocator = - std::get(config_).allocator; + se::DeviceMemoryAllocator* device_allocator = config_.GetAllocator(); se::DeviceMemoryAllocator* allocator; optional se_allocator; if (device_allocator != nullptr) { @@ -532,7 +514,7 @@ GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction( initialize_buffer(result_buffer, result_shape); // Get canonical HLO. - std::string canonical_hlo = std::get<1>(AutotuneCacheKeyFromInstruction( + std::string canonical_hlo = std::get<1>(AutotunerUtil::ConvCacheKeyFromInstruction( instr, stream_exec->GetDeviceDescription().model_str())); TF_ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr)); @@ -559,7 +541,7 @@ GpuConvAlgorithmPicker::AutotuneOneConvRunner( const AutotuneRuntimeArguments& runtime_arguments) { auto alg = runner->ToAlgorithmDesc(); - se::StreamExecutor* stream_exec = std::get(config_).stream_exec; + se::StreamExecutor* stream_exec = config_.GetExecutor(); XLA_SCOPED_LOGGING_TIMER_LEVEL( absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ", alg.ToString()), @@ -798,7 +780,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, se::Stream* stream, std::optional instruction_info, const AutotuneRuntimeArguments& runtime_arguments) { - se::StreamExecutor* stream_exec = std::get(config_).stream_exec; + se::StreamExecutor* stream_exec = config_.GetExecutor(); std::string instr_str = instruction_info.has_value() ? instruction_info->instr_str.c_str() @@ -914,8 +896,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( } TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm, - PickBestResult(profile_results, instr_str, - runtime_arguments.hlo_module_config)); + PickBestResult(profile_results, instr_str)); return selected_algorithm; } #endif @@ -962,7 +943,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( instr->GetModule()->config().debug_options(); const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops(); - se::StreamExecutor* stream_exec = std::get(config_).stream_exec; + se::StreamExecutor* stream_exec = config_.GetExecutor(); const auto device_ordinal = stream_exec->device_ordinal(); std::vector operand_buffers; @@ -1072,8 +1053,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( } TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm, - PickBestResult(profile_results, instr->ToString(), - instr->GetModule()->config())); + PickBestResult(profile_results, instr->ToString())); return selected_algorithm; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h index 5b2905891f0ab3..cd968d49b2c2c0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_serializable_autotuner.h" +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" @@ -79,7 +79,7 @@ class GpuConvAlgorithmPicker : public HloModulePass { static Status WriteAutotuneResults(AutotuneResults* results); static Status LoadAutotuneResults(const AutotuneResults& results); - explicit GpuConvAlgorithmPicker(AutotuningConfig config) : config_(config) {} + explicit GpuConvAlgorithmPicker(AutotuneConfig config) : config_(config) {} absl::string_view name() const override { return "gpu-conv-algorithm-picker"; @@ -172,7 +172,7 @@ class GpuConvAlgorithmPicker : public HloModulePass { se::DeviceMemoryAllocator* allocator, se::Stream* stream); private: - AutotuningConfig config_; + AutotuneConfig config_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_serializable_autotuner.h b/tensorflow/compiler/xla/service/gpu/gpu_serializable_autotuner.h deleted file mode 100644 index b80e96e89f463a..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/gpu_serializable_autotuner.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SERIALIZABLE_AUTOTUNER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SERIALIZABLE_AUTOTUNER_H_ - -#include -#include -#include - -#include "tensorflow/compiler/xla/autotune_results.pb.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" -#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" -#include "tensorflow/compiler/xla/types.h" - -namespace xla { -namespace gpu { - -struct DeviceConfig { - se::StreamExecutor* stream_exec; // never null - - // If the `allocator` parameter is not null, we will use it to allocate temp - // memory while timing the various convolution algorithms. If it's null, - // we'll use the default allocator on the StreamExecutor. - se::DeviceMemoryAllocator* allocator; // may be null -}; - -struct DevicelessConfig { - // The human-readable description of the device. It can be found by using - // stream_exec->GetDeviceDescription().model_str() when the stream executor - // is available. - std::string model_str; - - // A field to determine the architecture of the device. We only pick an - // algorithm for non-Ampere architectures. - se::CudaComputeCapability cuda_compute_capability{0, 0}; -}; - -struct AutotuningConfig : public std::variant { - using std::variant::variant; - bool is_offline() const { - return std::holds_alternative(*this); - } - bool is_online() const { return std::holds_alternative(*this); } -}; - -using AutotuneCacheKey = - std::tupleGetDeviceDescription().model_str()*/, - std::string /* instr->ToString(HloPrintOptions::Canonical()) */>; - -using AutotuneCacheMap = - absl::flat_hash_map; - -inline AutotuneCacheKey AutotuneCacheKeyFromInstruction( - const HloInstruction* instr, absl::string_view model_str) { - auto options = HloPrintOptions::Canonical(); - options.set_print_backend_config(true); - return std::make_tuple(std::string(model_str), instr->ToString(options)); -} - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SERIALIZABLE_AUTOTUNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.cc b/tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.cc new file mode 100644 index 00000000000000..b7e8f3176dc404 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.cc @@ -0,0 +1,185 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +//#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" + +namespace xla { +namespace gpu { + +struct MatmulPlanCache { + + static MatmulPlanCache& i(const se::Stream *stream) { + static absl::Mutex m(absl::kConstInit); + // Each GPU gets different cache instance + static std::vector< std::unique_ptr< MatmulPlanCache > > meta(8); + absl::MutexLock lock(&m); + size_t dev_id = stream->parent()->device_ordinal(); + if (meta.size() < dev_id) meta.resize(dev_id + 1); + auto& res = meta[dev_id]; + if (!res) res.reset(new MatmulPlanCache()); + return *res; + } + + template < class Func > + StatusOr + GetOrCreate(const std::string& key, Func&& create) { + // each GPU has a different mutex => hence different GPU instances can + // create matmul plans in parallel + absl::MutexLock lock(mutex_.get()); + auto res = map_.emplace(key, se::gpu::BlasLt::MatmulPlanPtr{}); + if(res.second) { // new entry inserted + TF_ASSIGN_OR_RETURN(res.first->second, create()); + } + return res.first->second.get(); + } + +private: + MatmulPlanCache() : mutex_(std::make_unique< absl::Mutex >()) { } + +private: + std::unique_ptr< absl::Mutex > mutex_; + absl::flat_hash_map map_; +}; + +CublasLtMatmulThunk::CublasLtMatmulThunk( + ThunkInfo thunk_info, GemmConfig config, + BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, + BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, + BufferAllocation::Slice bias_buffer /* may be null */, + BufferAllocation::Slice aux_buffer /* may be null */, + BufferAllocation::Slice a_scale /* may be null */, + BufferAllocation::Slice b_scale /* may be null */, + BufferAllocation::Slice c_scale /* may be null */, + BufferAllocation::Slice d_scale /* may be null */, + BufferAllocation::Slice d_amax /* may be null */, + absl::optional workspace) + : Thunk(Kind::kCublasLtMatmul, thunk_info), + gemm_config_(std::move(config)), + a_buffer_(a_buffer), + b_buffer_(b_buffer), + c_buffer_(c_buffer), + d_buffer_(d_buffer), + bias_buffer_(bias_buffer), + aux_buffer_(aux_buffer), + a_scale_buffer_(a_scale), + b_scale_buffer_(b_scale), + c_scale_buffer_(c_scale), + d_scale_buffer_(d_scale), + d_amax_buffer_(d_amax), + workspace_buffer_(workspace) +{ + canonical_hlo_ = se::gpu::ToCSVString(gemm_config_, /*full_string*/true); + // set algorithm ID explicitly to -1 if tuning is disabled! + + if(GetDebugOptionsFromFlags().xla_gpu_autotune_level() == 0) { + gemm_config_.algorithm = se::blas::kDefaultAlgorithm; + } +} + +Status CublasLtMatmulThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + if (!executor->AsBlas()) { + return InternalError("Failed to initialize BLASLT support"); + } + return OkStatus(); +} + +Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { + + TF_ASSIGN_OR_RETURN(auto *plan, GetCachedMatmulPlan(params)); + + VLOG(2) << params.stream->parent()->device_ordinal() << + ": cublas_lt_matmul for: " << canonical_hlo_; + const BufferAllocations& allocs = *params.buffer_allocations; + + se::DeviceMemoryBase bias, a_scale, b_scale, c_scale, d_scale, d_amax; + if (bias_buffer_.allocation() != nullptr) { + bias = allocs.GetDeviceAddress(bias_buffer_); + } + if (a_scale_buffer_.allocation() != nullptr) { + a_scale = allocs.GetDeviceAddress(a_scale_buffer_); + } + if (b_scale_buffer_.allocation() != nullptr) { + b_scale = allocs.GetDeviceAddress(b_scale_buffer_); + } + if (c_scale_buffer_.allocation() != nullptr) { + c_scale = allocs.GetDeviceAddress(c_scale_buffer_); + } + if (d_scale_buffer_.allocation() != nullptr) { + d_scale = allocs.GetDeviceAddress(d_scale_buffer_); + } + if (d_amax_buffer_.allocation() != nullptr) { + d_amax = allocs.GetDeviceAddress(d_amax_buffer_); + } + + se::DeviceMemoryBase aux; + if (aux_buffer_.allocation() != nullptr) { + aux = allocs.GetDeviceAddress(aux_buffer_); + } + + absl::optional workspace; + if (workspace_buffer_) { + workspace = allocs.GetDeviceAddress(*workspace_buffer_); + } + + return plan->ExecuteOnStream( + params.stream, allocs.GetDeviceAddress(a_buffer_), + allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_), + allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, + d_scale, d_amax, workspace); +} + +auto CublasLtMatmulThunk::GetCachedMatmulPlan( + const ExecuteParams& params) -> StatusOr { + + auto& cache = MatmulPlanCache::i(params.stream); + + auto create = [&]() -> StatusOr { + VLOG(2) << this << ": Adding new MatmulPlan for stream: " << params.stream << + " cfg: " << canonical_hlo_; + + TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan( + params.stream, gemm_config_)); + + int64_t max_workspace = workspace_buffer_.has_value() + ? workspace_buffer_.value().size() : 0, + algorithm_id = gemm_config_.algorithm; + int64_t num_algorithms = algorithm_id == se::blas::kDefaultAlgorithm ? + 1 : se::gpu::BlasLt::kMaxAlgorithms; + TF_ASSIGN_OR_RETURN(auto algorithms, + plan->GetAlgorithms(num_algorithms, max_workspace)); + + if (algorithm_id == se::blas::kDefaultAlgorithm && !algorithms.empty()) { + algorithm_id = algorithms[0].id; + } + for(const auto& alg : algorithms) { + if (alg.id == algorithm_id) { + TF_RETURN_IF_ERROR(plan->SetAlgorithm(alg)); + return std::move(plan); + } + } + TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[0])); + LOG(WARNING) << "Wrong algorithm ID: " << algorithm_id << " use default instead."; + return std::move(plan); + }; + return cache.GetOrCreate(canonical_hlo_, create); +} + +} // namespace gpu +} // namespace xla + diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h b/tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.h similarity index 52% rename from tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h rename to tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.h index da156e322ee395..54a0eba8abdeed 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,41 +13,46 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ +#ifndef TENSORFLOW_COMPILER_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_ +#define TENSORFLOW_COMPILER_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_ +#include #include -#include #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" namespace xla { namespace gpu { class CublasLtMatmulThunk : public Thunk { public: - CublasLtMatmulThunk(ThunkInfo thunk_info, cublas_lt::MatmulPlan plan, - int64_t algorithm_idx, BufferAllocation::Slice a_buffer, - BufferAllocation::Slice b_buffer, - BufferAllocation::Slice c_buffer, - BufferAllocation::Slice d_buffer, - BufferAllocation::Slice bias_buffer /* may be null */, - BufferAllocation::Slice aux_buffer /* may be null */, - BufferAllocation::Slice a_scale_buffer /* may be null */, - BufferAllocation::Slice b_scale_buffer /* may be null */, - BufferAllocation::Slice c_scale_buffer /* may be null */, - BufferAllocation::Slice d_scale_buffer /* may be null */, - BufferAllocation::Slice d_amax_buffer /* may be null */); + + CublasLtMatmulThunk( + ThunkInfo thunk_info, GemmConfig config, + BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, + BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, + BufferAllocation::Slice bias_buffer /* may be null */, + BufferAllocation::Slice aux_buffer /* may be null */, + BufferAllocation::Slice a_scale_buffer /* may be null */, + BufferAllocation::Slice b_scale_buffer /* may be null */, + BufferAllocation::Slice c_scale_buffer /* may be null */, + BufferAllocation::Slice d_scale_buffer /* may be null */, + BufferAllocation::Slice d_amax_buffer /* may be null */, + absl::optional workspace_buffer); Status ExecuteOnStream(const ExecuteParams& params) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; private: - cublas_lt::MatmulPlan plan_; - int64_t algorithm_idx_; + StatusOr GetCachedMatmulPlan( + const ExecuteParams& params); + + GemmConfig gemm_config_; + std::string canonical_hlo_; BufferAllocation::Slice a_buffer_; BufferAllocation::Slice b_buffer_; BufferAllocation::Slice c_buffer_; @@ -59,10 +64,11 @@ class CublasLtMatmulThunk : public Thunk { BufferAllocation::Slice c_scale_buffer_; BufferAllocation::Slice d_scale_buffer_; BufferAllocation::Slice d_amax_buffer_; - std::optional algorithm_; + absl::optional workspace_buffer_; }; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ +#endif // TENSORFLOW_COMPILER_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_THUNK_H_ + diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 24e5b0543a7a84..c2f222511b1917 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -47,6 +47,11 @@ bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) { return shape.rank() == batch_dimensions_size + 2; } +// Return whether the given shape is rank 1 excluding the batch dimensions. +bool IsRank1(const Shape& shape, int64_t batch_dimensions_size) { + return shape.rank() == batch_dimensions_size + 1; +} + // Given a shape and a group of contiguous dimensions in the shape, returns // a tuple of three values (major, middle, minor), where major is the size of // the dimensions more major then the given dimensions, minor is the size of @@ -96,6 +101,7 @@ Shape GetShapeFromTensorType(mlir::Value value) { } // namespace + bool IsMatrixMultiplication(const HloInstruction& dot) { if (dot.opcode() != HloOpcode::kDot) { return false; @@ -107,9 +113,10 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { PrimitiveType output_primitive_type = dot.shape().element_type(); bool type_is_allowed = (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || - output_primitive_type == F16 || output_primitive_type == BF16 || - output_primitive_type == F32 || output_primitive_type == F64 || - output_primitive_type == C64 || output_primitive_type == C128) || + output_primitive_type == F16 || + output_primitive_type == BF16 || output_primitive_type == F32 || + output_primitive_type == F64 || output_primitive_type == C64 || + output_primitive_type == C128) || (output_primitive_type == S32 && lhs_shape.element_type() == S8 && rhs_shape.element_type() == S8); bool shapes_are_valid = @@ -120,17 +127,37 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { !ShapeUtil::IsZeroElementArray(lhs_shape) && !ShapeUtil::IsZeroElementArray(rhs_shape); - if (!shapes_are_valid) { + return shapes_are_valid; +} + +bool IsMatrixVectorMultiplication(const HloInstruction& dot) { + if (dot.opcode() != HloOpcode::kDot) { return false; } + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + + PrimitiveType output_primitive_type = dot.shape().element_type(); + bool type_is_allowed = + (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || + output_primitive_type == F16 || output_primitive_type == BF16 || + output_primitive_type == F32 || output_primitive_type == F64 || + output_primitive_type == C64 || output_primitive_type == C128) || + (output_primitive_type == S32 && lhs_shape.element_type() == S8 && + rhs_shape.element_type() == S8); - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), - rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); + bool shapes_are_valid = + type_is_allowed && + ((IsRank2(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) && + IsRank1(rhs_shape, dim_numbers.lhs_batch_dimensions_size())) || + (IsRank1(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) && + IsRank2(rhs_shape, dim_numbers.lhs_batch_dimensions_size()))) && + IsRank1(dot.shape(), dim_numbers.lhs_batch_dimensions_size()) && + !ShapeUtil::IsZeroElementArray(lhs_shape) && + !ShapeUtil::IsZeroElementArray(rhs_shape); - return true; + return shapes_are_valid; } Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 1a63f980f64e11..14454e4e226229 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -43,6 +43,7 @@ inline constexpr int64_t kMinTotalDimensionsToTransposeTiled = 64 * 128; // This function should never return "true" on instructions after // GemmRewriter pass has finished. bool IsMatrixMultiplication(const HloInstruction& dot); +bool IsMatrixVectorMultiplication(const HloInstruction& dot); inline constexpr int64_t WarpSize() { return 32; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index cae0f25ab2a70c..ccdccd3ed0a6f7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -91,6 +91,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" @@ -143,7 +144,6 @@ limitations under the License. #include "tensorflow/tsl/protobuf/dnn.pb.h" #if GOOGLE_CUDA -#include "tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_triton.h" #endif // GOOGLE_CUDA @@ -1147,8 +1147,6 @@ Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) { return OkStatus(); } -#if GOOGLE_CUDA - Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) { auto matmul = mlir::dyn_cast(op); TF_RET_CHECK(matmul != nullptr); @@ -1163,21 +1161,26 @@ Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(bias, GetAllocationSlice(matmul.getBias())); } - BufferAllocation::Slice aux; + BufferAllocation::Slice aux, workspace; if (matmul.getAux() != nullptr) { TF_ASSIGN_OR_RETURN(aux, GetAllocationSlice(matmul.getAux())); } + if (matmul.getWorkspace() != nullptr) { + TF_ASSIGN_OR_RETURN(workspace, GetAllocationSlice(matmul.getWorkspace())); + } - TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan plan, - cublas_lt::MatmulPlan::For(matmul)); + TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(matmul)); auto thunk = std::make_unique( - GetThunkInfo(op), std::move(plan), matmul.getAlgorithm(), a, b, c, d, - bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); + GetThunkInfo(op), std::move(config), a, b, c, d, + bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, + workspace); AddThunkToThunkSequence(std::move(thunk)); return OkStatus(); } +#if GOOGLE_CUDA + Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(mlir::Operation* op) { auto matmul = mlir::dyn_cast(op); TF_RET_CHECK(matmul != nullptr); @@ -5688,10 +5691,10 @@ Status IrEmitterUnnested::EmitOp(mlir::Operation* op) { return EmitGemmThunk(op); } -#if GOOGLE_CUDA if (mlir::isa(op)) { return EmitCublasLtMatmulThunk(op); } +#if GOOGLE_CUDA if (mlir::isa(op)) { return EmitCublasLtMatmulThunkF8(op); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 90e94ff1b6aff4..474f4495c785fa 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -198,13 +198,13 @@ class IrEmitterUnnested : public IrEmitter { Status EmitConvolutionThunk(mlir::Operation* op); Status EmitGemmThunk(mlir::Operation* op); #if GOOGLE_CUDA - Status EmitCublasLtMatmulThunk(mlir::Operation* op); Status EmitCublasLtMatmulThunkF8(mlir::Operation* op); Status EmitConvolutionReorderThunk(mlir::Operation* op); Status EmitTritonFusion(mlir::Operation* op, tensorflow::AutotuneResult::TritonGemmKey& config); #endif // GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + Status EmitCublasLtMatmulThunk(mlir::Operation* op); Status EmitCholeskyThunk(mlir::Operation* op); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM Status EmitCustomCallThunk(mlir::Operation* op); diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc index b8bcb7e670f71b..7554bfd2417168 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,38 +16,30 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include +#include #include #include +#include +#include #include -#include #include #include -#include "absl/algorithm/container.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" -#include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/stream_executor/blas.h" -#include "tensorflow/compiler/xla/stream_executor/device_memory.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/tsl/platform/statusor.h" - -#if GOOGLE_CUDA -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" -#include "tensorflow/compiler/xla/stream_executor/host_or_device_scalar.h" -#include "tensorflow/tsl/platform/tensor_float_32_utils.h" -#endif // GOOGLE_CUDA +#include "tensorflow/compiler/xla/status_macros.h" namespace xla { namespace gpu { +using se::blas::DataType; +using se::gpu::BlasLt; + StatusOr> GetNonContractingDims( const Shape& shape, absl::Span batch_dims, absl::Span contracting_dims) { @@ -66,10 +58,42 @@ StatusOr> GetNonContractingDims( return non_contracting_dims; } -StatusOr GetBatchRowColumnShape(const Shape& shape, - absl::Span batch_dims, - absl::Span row_dims, - absl::Span col_dims) { +const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( + const HloInstruction& dot, const int operand_number) { + const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); + if (operand_number == 0) { + return dimension_numbers.lhs_batch_dimensions(); + } + return dimension_numbers.rhs_batch_dimensions(); +} + +StatusOr ContractingDimensionIndex(const HloInstruction& dot, + const int operand_number) { + const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); + if (operand_number == 0) { + TF_RET_CHECK(dimension_numbers.lhs_contracting_dimensions().size() == 1); + return dimension_numbers.lhs_contracting_dimensions(0); + } + TF_RET_CHECK(dimension_numbers.rhs_contracting_dimensions().size() == 1); + return dimension_numbers.rhs_contracting_dimensions(0); +} + +StatusOr NonContractingDimensionIndex(const HloInstruction& dot, + const int operand_number) { + TF_ASSIGN_OR_RETURN(int64_t contracting_dim, + ContractingDimensionIndex(dot, operand_number)); + TF_ASSIGN_OR_RETURN( + std::vector non_contracting_dims, + GetNonContractingDims(dot.operand(operand_number)->shape(), + BatchDimensionsForOperand(dot, operand_number), + {contracting_dim})); + TF_RET_CHECK(non_contracting_dims.size() == 1); + return non_contracting_dims.front(); +} + +StatusOr GetBatchRowColumnShape( + const Shape& shape, absl::Span batch_dims, + absl::Span row_dims, absl::Span col_dims) { TF_RET_CHECK(shape.has_layout()); std::vector minor_to_major; @@ -77,7 +101,8 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, // The GeMM output always has its layout set such that the batch, row, and // col dim groups are each laid out physically sequentially. GeMM operands // must, therefore, be laid out similarly. - auto check_physically_sequential = [&](absl::Span dims) { + auto check_physically_sequential = + [&](absl::Span dims) -> Status { for (auto it = dims.rbegin(); it != dims.rend(); ++it) { // NOTE: `i` is incremented as we check the dimensions. if (*it != shape.layout().minor_to_major()[i++]) @@ -126,7 +151,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, int64_t num_rows = shape.dimensions(1); int64_t num_cols = shape.dimensions(2); - MatrixLayout::Order order = MatrixLayout::Order::kRowMajor; + Order order{Order::kRowMajor}; int64_t leading_dim_stride = num_cols; int64_t batch_stride = num_rows * num_cols; @@ -137,7 +162,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, case 012: // (B,R,C) (major-to-minor) break; case 021: // (B,C,R) - order = MatrixLayout::Order::kColumnMajor; + order = Order::kColumnMajor; leading_dim_stride = num_rows; break; case 0102: // (R,B,C) @@ -145,7 +170,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, batch_stride = num_cols; break; case 0201: // (C,B,R) - order = MatrixLayout::Order::kColumnMajor; + order = Order::kColumnMajor; leading_dim_stride = batch_size * num_rows; batch_stride = num_rows; break; @@ -153,11 +178,13 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, return Unimplemented("batch in most minor dimension"); } - if (batch_size == 1) batch_stride = 0; - return MatrixLayout{ - shape.element_type(), num_rows, num_cols, order, - leading_dim_stride, batch_size, batch_stride, - }; + if (batch_size == 1) { + batch_stride = 0; + } + + TF_ASSIGN_OR_RETURN(auto dtype, se::gpu::AsBlasDataType(shape.element_type())); + return MatrixLayout(dtype, num_rows, num_cols, order, + batch_size, leading_dim_stride, batch_stride); } /*static*/ StatusOr MatrixLayout::For( @@ -169,11 +196,9 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, return MatrixLayout::For(batch_row_col_shape); } -/*static*/ StatusOr MatrixLayout::For(const Shape& shape, - size_t lhs_num_batch_dims, - size_t lhs_num_row_dims, - size_t rhs_num_batch_dims, - size_t rhs_num_col_dims) { +/*static*/ StatusOr MatrixLayout::For( + const Shape& shape, size_t lhs_num_batch_dims, size_t lhs_num_row_dims, + size_t rhs_num_batch_dims, size_t rhs_num_col_dims) { size_t num_batch_dims = std::max(lhs_num_batch_dims, rhs_num_batch_dims); TF_RET_CHECK(shape.rank() == @@ -190,11 +215,6 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, return MatrixLayout::For(shape, batch_dims, row_dims, col_dims); } -void MatrixLayout::Transpose() { - std::swap(num_rows, num_cols); - order = (order == Order::kRowMajor) ? Order::kColumnMajor : Order::kRowMajor; -} - namespace { // Returns the relative order of 'dims' as indices from 0 to dims.size() - 1. // Let 'indices' be the returned vector, then it holds that @@ -211,7 +231,10 @@ std::vector NormalizedRelativeOrder(absl::Span dims) { } // namespace StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, - int64_t operand_idx) { + int64_t operand_idx) { + // if (Cast(&dot)->sparse_operands()) { + // return false; + // } TF_RET_CHECK(dot.opcode() == HloOpcode::kDot); TF_RET_CHECK(dot.operand_count() > operand_idx); @@ -261,23 +284,9 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, - std::optional algorithm, int64_t compute_precision, - bool gx, bool gy) { - return GemmConfig::For(lhs_shape, lhs_batch_dims, lhs_contracting_dims, - rhs_shape, rhs_batch_dims, rhs_contracting_dims, - output_shape, output_shape, alpha_real, alpha_imag, - beta, algorithm, compute_precision, - gx, gy); -} + int64_t algorithm, int64_t compute_precision, + BlasLt::Epilogue epilogue) { -/*static*/ StatusOr GemmConfig::For( - const Shape& lhs_shape, absl::Span lhs_batch_dims, - absl::Span lhs_contracting_dims, const Shape& rhs_shape, - absl::Span rhs_batch_dims, - absl::Span rhs_contracting_dims, const Shape& c_shape, - const Shape& output_shape, double alpha_real, double alpha_imag, double beta, - std::optional algorithm, int64_t compute_precision, - bool gx, bool gy) { absl::Span lhs_col_dims = lhs_contracting_dims; TF_ASSIGN_OR_RETURN( std::vector lhs_row_dims, @@ -315,21 +324,19 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, TF_ASSIGN_OR_RETURN(MatrixLayout output_layout, MatrixLayout::For(output_shape, output_batch_dims, output_row_dims, output_col_dims)); + Shape c_matrix_shape = output_shape; TF_ASSIGN_OR_RETURN(MatrixLayout c_layout, - MatrixLayout::For(c_shape, output_batch_dims, + MatrixLayout::For(c_matrix_shape, output_batch_dims, output_row_dims, output_col_dims)); // TODO(cjfj): We should also check that the batch, contracting and // non-contracting dimensions match in size and relative physical location. // TODO(philipphack): Check the remaining dimensions in the FP8 case once // cuBLASLt supports the NN configuration. - if (lhs_shape.element_type() != F8E4M3FN && - lhs_shape.element_type() != F8E5M2) { - TF_RET_CHECK(lhs_layout.num_cols == rhs_layout.num_rows); - TF_RET_CHECK(output_layout.num_rows == lhs_layout.num_rows); - TF_RET_CHECK(output_layout.num_cols == rhs_layout.num_cols); - } + TF_RET_CHECK(lhs_layout.num_cols == rhs_layout.num_rows); + TF_RET_CHECK(output_layout.num_rows == lhs_layout.num_rows); + TF_RET_CHECK(output_layout.num_cols == rhs_layout.num_cols); TF_RET_CHECK(c_layout.num_rows == output_layout.num_rows); TF_RET_CHECK(c_layout.num_cols == output_layout.num_cols); TF_RET_CHECK((lhs_layout.batch_size == output_layout.batch_size) || @@ -338,8 +345,6 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, (rhs_layout.batch_size == 1)); switch (output_shape.element_type()) { - case F8E4M3FN: - case F8E5M2: case F16: case BF16: case F32: @@ -351,40 +356,38 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, break; case S32: TF_RET_CHECK(alpha_imag == 0); - if (lhs_layout.dtype != PrimitiveType::S8 || - rhs_layout.dtype != PrimitiveType::S8) { - return InternalError( - "For int32 gemm output only int8 input is supported, got input: " - "%s, %s", - primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype), - primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype)); + if (lhs_layout.dtype != DataType::kInt8 || + rhs_layout.dtype != DataType::kInt8) { + return Internal( + "For int32 gemm output only int8 input is supported !"); } break; default: - return InternalError("Unexpected GEMM datatype: %s", - primitive_util::LowercasePrimitiveTypeName( - output_shape.element_type())); + return Internal("Unexpected GEMM datatype: %s", + primitive_util::LowercasePrimitiveTypeName( + output_shape.element_type())); } - return GemmConfig{ - lhs_layout, - rhs_layout, - c_layout, - output_layout, - {alpha_real, alpha_imag}, - beta, - algorithm, - compute_precision, - gx, - gy - }; + return GemmConfig( + se::gpu::GemmConfig{ + .lhs_layout = lhs_layout, + .rhs_layout = rhs_layout, + .c_layout = c_layout, + .output_layout = output_layout, + .alpha = {alpha_real, alpha_imag}, + .beta = beta, + .algorithm = algorithm, + .compute_precision = compute_precision, + .epilogue = epilogue}); } -/*static*/ StatusOr GemmConfig::For(const HloInstruction* gemm) { - TF_ASSIGN_OR_RETURN(GemmBackendConfig config, +/*static*/ StatusOr GemmConfig::For( + const HloInstruction* gemm) { + + TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); - std::optional algorithm; + int64_t algorithm = se::blas::kDefaultAlgorithm; if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) { algorithm = config.selected_algorithm(); } @@ -394,44 +397,38 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, const DotDimensionNumbers& dot_dims = config.dot_dimension_numbers(); const Shape& output_shape = gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0) : gemm->shape(); - auto attributes = gemm->frontend_attributes().map(); - bool gx = (attributes["grad_x"] == "true"); - bool gy = (attributes["grad_y"] == "true"); + + int64_t precision = se::blas::kDefaultComputePrecision; + for (auto operand_precision : config.precision_config().operand_precision()) { + precision = std::max(precision, static_cast(operand_precision)); + } + TF_ASSIGN_OR_RETURN(auto epilogue, + gpublas_lt::AsBlasLtEpilogue(config.epilogue())); return GemmConfig::For( lhs_shape, dot_dims.lhs_batch_dimensions(), dot_dims.lhs_contracting_dimensions(), rhs_shape, - dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(), - /*output_shape=*/gemm->shape(), config.alpha_real(), config.alpha_imag(), - config.beta(), algorithm, se::blas::kDefaultComputePrecision, - gx, gy); + dot_dims.rhs_batch_dimensions(), + dot_dims.rhs_contracting_dimensions(), + output_shape, config.alpha_real(), config.alpha_imag(), config.beta(), + algorithm, precision, epilogue); } /*static*/ StatusOr GemmConfig::For(mlir::lmhlo_gpu::GEMMOp op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - std::optional algorithm; + + auto dot_dims = op.getDotDimensionNumbers(); + int64_t algorithm = se::blas::kDefaultAlgorithm; if (op.getAlgorithm()) algorithm = *op.getAlgorithm(); - bool gx=false, gy=false; - auto attr_grad_x = op.getGradX(); - if (attr_grad_x) - gx=attr_grad_x.value(); - auto attr_grad_y = op.getGradY(); - if (attr_grad_y) - gx=attr_grad_y.value(); - + int64_t compute_precision = 0; // Default if (op.getPrecisionConfig().has_value()) { auto precision_config = op.getPrecisionConfig(); for (auto attr : precision_config.value()) { int64_t value = static_cast( attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } + compute_precision = std::max(value, compute_precision); } } - return GemmConfig::For( GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), @@ -439,282 +436,276 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), algorithm, compute_precision, - gx, gy); -} - -StatusOr GetBlasComputationType( - PrimitiveType lhs_dtype, PrimitiveType output_dtype, - int64_t compute_precision) { - switch (output_dtype) { - case F8E5M2: // fall-through - case F8E4M3FN: // fall-through - case F16: // fall-through - case BF16: - // Accumulate in f32 precision. - return se::blas::ComputationType::kF32; - case F32: // fall-through - case C64: -#if GOOGLE_CUDA - if (tsl::tensor_float_32_execution_enabled() && compute_precision <= 1 && - lhs_dtype != F8E4M3FN && lhs_dtype != F8E5M2) { - // CublasLt requires compute type to be F32 for F8 matmul. - return se::blas::ComputationType::kTF32AsF32; - } -#endif - return se::blas::ComputationType::kF32; - case F64: // fall-through - case C128: - return se::blas::ComputationType::kF64; - case S32: - return se::blas::ComputationType::kI32; - default: - return InternalError("GetBlasComputationType: unsupported type"); - } -} - -namespace cublas_lt { - -se::blas::DataType GetScaleType(se::blas::DataType c_type, - se::blas::ComputationType computation_type) { - return ((computation_type == se::blas::ComputationType::kF32) && - (c_type != se::blas::DataType::kComplexFloat)) - ? se::blas::DataType::kFloat - : c_type; + BlasLt::Epilogue::kDefault); } -} // namespace cublas_lt - -namespace { +/*static*/ StatusOr GemmConfig::For( + mlir::lmhlo_gpu::CublasLtMatmulOp op) { -// This struct contains the metadata of a matrix, e.g., its base address and -// dimensions. -struct MatrixDescriptor { - se::DeviceMemoryBase data; - int64_t leading_dim_stride; - int64_t batch_stride; - se::blas::Transpose transpose; - - template - se::DeviceMemory cast() const { - return se::DeviceMemory(data); - } -}; - -// BLAS GeMM's output is column-major. If we require row-major, use identity: -// C^T = (A @ B)^T = B^T @ A^T. -bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, - MatrixLayout& output) { - bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor; - if (swap_operands) { - std::swap(lhs, rhs); - lhs.Transpose(); - rhs.Transpose(); - output.Transpose(); + auto dot_dims = op.getDotDimensionNumbers(); + int64_t algorithm = op.getAlgorithm(); + TF_ASSIGN_OR_RETURN(auto epilogue, gpublas_lt::AsBlasLtEpilogue(op.getEpilogue())); + + int64_t compute_precision = 0; // Default + if (op.getPrecisionConfig().has_value()) { + auto precision_config = op.getPrecisionConfig(); + for (auto attr : precision_config.value()) { + int64_t value = static_cast( + attr.template cast().getValue()); + compute_precision = std::max(value, compute_precision); + } } - return swap_operands; + return GemmConfig::For( + GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), + dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), + dot_dims.getRhsBatchingDimensions(), + dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), + op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(), + op.getBeta().convertToDouble(), algorithm, compute_precision, + epilogue); } -bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, - MatrixLayout& output, MatrixLayout& c) { - bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor; - if (swap_operands) { - std::swap(lhs, rhs); - rhs.Transpose(); - lhs.Transpose(); - c.Transpose(); - output.Transpose(); +StatusOr GemmConfig::GetMatrixDescriptors( + se::DeviceMemoryBase lhs_buf, se::DeviceMemoryBase rhs_buf, + se::DeviceMemoryBase out_buf) const { + auto create_matrix_desc = [](const se::gpu::MatrixLayout& layout, + se::DeviceMemoryBase data) { + return se::gpu::MatrixDescriptor{ + data, layout.leading_dim_stride, layout.batch_stride, + layout.dtype, + // BLAS is column-major by default. + (layout.order == se::gpu::MatrixLayout::Order::kColumnMajor + ? se::blas::Transpose::kNoTranspose + : se::blas::Transpose::kTranspose)}; + }; + // TODO: make a local copy to prevent modification of layouts, + // but maybe we can modify them once instead during creation ? + se::gpu::MatrixLayout lhs = lhs_layout, rhs = rhs_layout, out = output_layout; + + bool must_swap_operands = MakeOutputColumnMajor(lhs, rhs, out); + if (must_swap_operands) { + std::swap(lhs_buf, rhs_buf); } - return swap_operands; -} -se::blas::Transpose AsBlasTranspose(MatrixLayout::Order order) { - // BLAS is column-major by default. - return (order == MatrixLayout::Order::kColumnMajor) - ? se::blas::Transpose::kNoTranspose - : se::blas::Transpose::kTranspose; + se::gpu::OutputMatrixDescriptor out_desc = create_matrix_desc(out, out_buf); + out_desc.batch_size = out.batch_size; + out_desc.m = out.num_rows; + out_desc.n = out.num_cols; + out_desc.k = lhs.num_cols; + // TODO(tdanyluk): Investigate why don't we use the actual precision (and + // algorithm) here? Why do we use the default? + TF_ASSIGN_OR_RETURN(out_desc.compute_type, + se::gpu::GetBlasComputationType( + lhs.dtype, out.dtype, -1)); + + se::gpu::MatrixDescriptor lhs_desc = create_matrix_desc(lhs, lhs_buf), + rhs_desc = create_matrix_desc(rhs, rhs_buf); + + return DescriptorsTuple{lhs_desc, rhs_desc, out_desc, must_swap_operands}; } -MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout, - se::DeviceMemoryBase data) { - return { - data, - layout.leading_dim_stride, - layout.batch_stride, - AsBlasTranspose(layout.order), - }; -} +namespace { -template -Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k, - const MatrixDescriptor& lhs, - const MatrixDescriptor& rhs, - const MatrixDescriptor& output, Output alpha, - Output beta, se::Stream* stream, - se::blas::AlgorithmType algorithm, - se::blas::ComputePrecision compute_precision, - se::blas::ProfileResult* profile_result, - se::blas::CallContext context) { +template +Status DoGemmWithAlgorithm(const se::gpu::MatrixDescriptor& lhs, + const se::gpu::MatrixDescriptor& rhs, + const se::gpu::OutputMatrixDescriptor& output, + se::DeviceMemoryBase workspace, Scale alpha, + Scale beta, se::Stream* stream, + se::blas::AlgorithmType algorithm, + se::blas::ComputePrecision compute_precision, + const se::NumericOptions& numeric_options, + se::blas::ProfileResult* profile_result, + se::blas::CallContext context) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); - PrimitiveType lhs_type = primitive_util::NativeToPrimitiveType(); - PrimitiveType output_type = primitive_util::NativeToPrimitiveType(); TF_ASSIGN_OR_RETURN( se::blas::ComputationType computation_type, - GetBlasComputationType(lhs_type, output_type, compute_precision)); + se::gpu::GetBlasComputationType(lhs.type, + output.type, compute_precision)); se::DeviceMemory output_data(output.data); - if (batch_size != 1) { + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } + // Set a workspace for all Blas operations launched below. + se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); + + if (output.batch_size != 1) { return stream->ThenBlasGemmStridedBatchedWithAlgorithm( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), - rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, - output.leading_dim_stride, output.batch_stride, batch_size, - computation_type, algorithm, compute_precision, profile_result, - context); + lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, + rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, + &output_data, output.leading_dim_stride, output.batch_stride, + output.batch_size, computation_type, algorithm, compute_precision, + profile_result, context); } else { return stream->ThenBlasGemmWithAlgorithm( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, - &output_data, output.leading_dim_stride, computation_type, algorithm, - compute_precision, profile_result, context); + lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, rhs.cast(), + rhs.leading_dim_stride, beta, &output_data, output.leading_dim_stride, + computation_type, algorithm, compute_precision, profile_result, context); } } -template -Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, - const MatrixDescriptor& lhs, const MatrixDescriptor& rhs, - const MatrixDescriptor& output, Input alpha, Input beta, - se::Stream* stream, - std::optional algorithm, - se::blas::ComputePrecision compute_precision, - se::blas::ProfileResult* profile_result, - se::blas::CallContext context) { +template +Status DoGemm(const se::gpu::MatrixDescriptor& lhs, + const se::gpu::MatrixDescriptor& rhs, + const se::gpu::OutputMatrixDescriptor& output, + se::DeviceMemoryBase workspace, Scale alpha, Scale beta, + se::Stream* stream, + std::optional algorithm, + se::blas::ComputePrecision compute_precision, + const se::NumericOptions& numeric_options, + se::blas::ProfileResult* profile_result, + se::blas::CallContext context) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); - se::DeviceMemory output_data(output.data); + se::DeviceMemory output_data(output.data); if (algorithm) { - return DoGemmWithAlgorithm( - batch_size, m, n, k, lhs, rhs, output, alpha, beta, stream, *algorithm, - compute_precision, profile_result, context); + return DoGemmWithAlgorithm( + lhs, rhs, output, workspace, alpha, beta, stream, + *algorithm, compute_precision, numeric_options, profile_result, + context); } - if (batch_size != 1) { + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } + // Set a workspace for all Blas operations launched below. + se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); + + if (output.batch_size != 1) { return stream->ThenBlasGemmStridedBatched( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), - rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, - output.leading_dim_stride, output.batch_stride, batch_size, - compute_precision, context); + lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, + rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, + &output_data, output.leading_dim_stride, output.batch_stride, + output.batch_size, compute_precision, context); } - return stream->ThenBlasGemm( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, - &output_data, output.leading_dim_stride, compute_precision, - context); + return stream->ThenBlasGemm(lhs.transpose, rhs.transpose, output.m, + output.n, output.k, alpha, lhs.cast(), + lhs.leading_dim_stride, rhs.cast(), + rhs.leading_dim_stride, beta, &output_data, + output.leading_dim_stride, compute_precision, context); } +#define DT(E, T) \ + template <> struct DataTypeToNative { using Type = T; } + +template < DataType E > +struct DataTypeToNative {}; + +DT(kFloat, float); +DT(kDouble, double); +DT(kHalf, Eigen::half); +DT(kInt8, int8_t); +DT(kInt32, int32_t); +DT(kComplexFloat, complex64); +DT(kComplexDouble, complex128); +DT(kBF16, Eigen::bfloat16); + +template < DataType E > +using DataType_t = typename DataTypeToNative< E >::Type; + +#undef DT + } // namespace Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, se::Stream* stream, - std::optional algorithm, - se::blas::ProfileResult* profile_result) { + se::DeviceMemoryBase rhs_buffer, + se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase workspace_buffer, + bool deterministic_ops, se::Stream* stream, + std::optional algorithm, + se::blas::ProfileResult* profile_result) { VLOG(2) << "Executing a GemmThunk"; - MatrixLayout lhs_layout = config.lhs_layout; - MatrixLayout rhs_layout = config.rhs_layout; - MatrixLayout output_layout = config.output_layout; - bool must_swap_operands = - MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout); - if (must_swap_operands) { - std::swap(lhs_buffer, rhs_buffer); - } + TF_ASSIGN_OR_RETURN( + GemmConfig::DescriptorsTuple desc, + config.GetMatrixDescriptors(lhs_buffer, rhs_buffer, output_buffer)); - int64_t m = output_layout.num_rows; - int64_t n = output_layout.num_cols; - int64_t k = lhs_layout.num_cols; - MatrixDescriptor lhs = GetMatrixDesc(lhs_layout, lhs_buffer); - MatrixDescriptor rhs = GetMatrixDesc(rhs_layout, rhs_buffer); - MatrixDescriptor output = GetMatrixDesc(output_layout, output_buffer); - int64_t batch_size = output_layout.batch_size; + se::NumericOptions numeric_options(deterministic_ops); if (!algorithm) algorithm = config.algorithm; se::blas::CallContext context = se::blas::CallContext::kNone; - if (config.grad_x) { - context = must_swap_operands - ? se::blas::CallContext::kBackpropInput2 - : se::blas::CallContext::kBackpropInput1; - } - if (config.grad_y) { - context = must_swap_operands - ? se::blas::CallContext::kBackpropInput1 - : se::blas::CallContext::kBackpropInput2; + std::tuple operand_types{config.lhs_layout.dtype, config.rhs_layout.dtype, + config.output_layout.dtype}; + + // Skip degenerate gemm with memzero. In general this is not safe, because it + // will suppress NaN propagation, however cuBLAS internally has exactly the + // same optimization for compatibility with NETLIB implementation, so we are + // not making things worse (and cuBLAS optimization is incompatible with CUDA + // graphs, so we are making sure we do not trigger it). + if (config.alpha.real() == 0.0 && config.alpha.imag() == 0.0 && + config.beta == 0.0) { + stream->ThenMemZero(&output_buffer, output_buffer.size()); + return OkStatus(); } - if ((output_layout.dtype == F16 || output_layout.dtype == BF16 || - output_layout.dtype == F32 || output_layout.dtype == F64 || - output_layout.dtype == C64 || output_layout.dtype == C128) && - (lhs_layout.dtype != output_layout.dtype || - rhs_layout.dtype != output_layout.dtype)) { - return InternalError( - "GEMM lhs type(%s) and rhs type(%s) must match output type(%s)", - primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype), - primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype), - primitive_util::LowercasePrimitiveTypeName(output_layout.dtype)); +#define TYPED_GEMM(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ + if (operand_types == std::tuple{DataType::ATYPE, DataType::BTYPE, DataType::CTYPE}) { \ + using NativeScaleType = DataType_t; \ + using NativeAType = DataType_t; \ + using NativeCType = DataType_t; \ + return DoGemm( \ + desc.lhs, desc.rhs, desc.output, workspace_buffer, \ + static_cast(config.alpha.real()), \ + static_cast(config.beta), stream, \ + algorithm, config.compute_precision, \ + numeric_options, profile_result, context); \ } - switch (output_layout.dtype) { - case S32: - if (!algorithm) algorithm = se::blas::kDefaultGemmAlgo; - return DoGemmWithAlgorithm( - batch_size, m, n, k, lhs, rhs, output, - static_cast(config.alpha.real()), - static_cast(config.beta), stream, *algorithm, - se::blas::kDefaultComputePrecision, profile_result, - context); - case F16: - return DoGemm(batch_size, m, n, k, lhs, rhs, output, - static_cast(config.alpha.real()), - static_cast(config.beta), stream, - algorithm, config.compute_precision, - profile_result, context); - case BF16: - return DoGemm( - batch_size, m, n, k, lhs, rhs, output, - static_cast(config.alpha.real()), - static_cast(config.beta), stream, algorithm, - config.compute_precision, profile_result, context); - case F32: - return DoGemm(batch_size, m, n, k, lhs, rhs, output, - config.alpha.real(), config.beta, stream, algorithm, - config.compute_precision, profile_result, context); - case F64: - return DoGemm(batch_size, m, n, k, lhs, rhs, output, - config.alpha.real(), config.beta, stream, algorithm, - config.compute_precision, profile_result, context); - case C64: - return DoGemm(batch_size, m, n, k, lhs, rhs, output, - static_cast(config.alpha), - static_cast(config.beta), stream, - algorithm, config.compute_precision, - profile_result, context); - case C128: - return DoGemm( - batch_size, m, n, k, lhs, rhs, output, config.alpha, - static_cast(config.beta), stream, algorithm, - config.compute_precision, profile_result, context); - default: - return InternalError( - "Unexpected GEMM dtype: %s", - primitive_util::LowercasePrimitiveTypeName(output_layout.dtype)); +#define TYPED_GEMM_COMPLEX(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ + if (operand_types == std::tuple(DataType::ATYPE, DataType::BTYPE, DataType::CTYPE)) { \ + using NativeScaleType = DataType_t; \ + using NativeAType = DataType_t; \ + using NativeCType = DataType_t; \ + return DoGemm( \ + desc.lhs, desc.rhs, desc.output, workspace_buffer, \ + static_cast(config.alpha), \ + static_cast(config.beta), stream, \ + algorithm, config.compute_precision, \ + numeric_options, profile_result, context); \ } -} -namespace cublas_lt { + // if (config.output_layout.dtype == DataType::kInt32) { + // if (!algorithm) algorithm = se::blas::kDefaultGemmAlgo; + // // TODO(tdanyluk): Investigate why don't we use the actual precision (and + // // algorithm) here? Why do we use the default? + // return DoGemmWithAlgorithm( + // desc.lhs, desc.rhs, desc.output, workspace_buffer, + // static_cast(config.alpha.real()), + // static_cast(config.beta), stream, + // *algorithm, se::blas::kDefaultComputePrecision, numeric_options, + // profile_result, context); + // } + + TYPED_GEMM(kFloat, kBF16, kBF16, kBF16) + TYPED_GEMM(kFloat, kHalf, kHalf, kHalf) + // TYPED_GEMM(kFloat, kInt8, kInt8, kFloat) + // TYPED_GEMM(kFloat, kBF16, kBF16, kFloat) + // TYPED_GEMM(kFloat, kHalf, kHalf, kFloat) + TYPED_GEMM(kFloat, kFloat, kFloat, kFloat) + TYPED_GEMM(kDouble, kDouble, kDouble, kDouble) + TYPED_GEMM_COMPLEX(kComplexFloat, kComplexFloat, kComplexFloat, kComplexFloat) + TYPED_GEMM_COMPLEX(kComplexDouble, kComplexDouble, kComplexDouble, kComplexDouble) + +#undef TYPED_GEMM +#undef TYPED_GEMM_COMPLEX + return Internal( + "Unexpected GEMM dtype: %d %d %d", + (int)(config.lhs_layout.dtype), (int)(config.rhs_layout.dtype), + (int)(config.output_layout.dtype)); +} // namespace gpu + +namespace gpublas_lt { -StatusOr EpilogueAddsVectorBias(GemmBackendConfig_Epilogue epilogue) { +StatusOr EpilogueAddsVectorBias( + GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { case GemmBackendConfig::DEFAULT: case GemmBackendConfig::RELU: @@ -727,11 +718,12 @@ StatusOr EpilogueAddsVectorBias(GemmBackendConfig_Epilogue epilogue) { case GemmBackendConfig::BIAS_GELU_AUX: return true; default: - return InternalError("Unknown Epilogue."); + return Internal("Unknown BlasLt::Epilogue."); } } -StatusOr EpilogueHasAuxiliaryOutput(GemmBackendConfig_Epilogue epilogue) { +StatusOr EpilogueHasAuxiliaryOutput( + GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { case GemmBackendConfig::DEFAULT: case GemmBackendConfig::RELU: @@ -744,312 +736,111 @@ StatusOr EpilogueHasAuxiliaryOutput(GemmBackendConfig_Epilogue epilogue) { case GemmBackendConfig::BIAS_GELU_AUX: return true; default: - return InternalError("Unknown Epilogue."); + return Internal("Unknown BlasLt::Epilogue."); } } -} // namespace cublas_lt - -StatusOr AsBlasDataType(PrimitiveType dtype) { - switch (dtype) { - case F8E5M2: - return se::blas::DataType::kF8E5M2; - case F8E4M3FN: - return se::blas::DataType::kF8E4M3FN; - case S8: - return se::blas::DataType::kInt8; - case F16: - return se::blas::DataType::kHalf; - case BF16: - return se::blas::DataType::kBF16; - case F32: - return se::blas::DataType::kFloat; - case S32: - return se::blas::DataType::kInt32; - case F64: - return se::blas::DataType::kDouble; - case C64: - return se::blas::DataType::kComplexFloat; - case C128: - return se::blas::DataType::kComplexDouble; +StatusOr AsBlasLtEpilogue( + GemmBackendConfig_Epilogue epilogue) { + switch (epilogue) { + case GemmBackendConfig::DEFAULT: + return BlasLt::Epilogue::kDefault; + case GemmBackendConfig::RELU: + return BlasLt::Epilogue::kReLU; + case GemmBackendConfig::GELU: + return BlasLt::Epilogue::kGELU; + case GemmBackendConfig::GELU_AUX: + return BlasLt::Epilogue::kGELUWithAux; + case GemmBackendConfig::BIAS: + return BlasLt::Epilogue::kBias; + case GemmBackendConfig::BIAS_RELU: + return BlasLt::Epilogue::kBiasThenReLU; + case GemmBackendConfig::BIAS_GELU: + return BlasLt::Epilogue::kBiasThenGELU; + case GemmBackendConfig::BIAS_GELU_AUX: + return BlasLt::Epilogue::kBiasThenGELUWithAux; default: - return InternalError("AsBlasDataType: unsupported type"); + return Internal("unexpected epilogue value"); } } -#if GOOGLE_CUDA - -namespace { - -StatusOr AsBlasLtMatrixLayout( - const MatrixLayout& layout) { - TF_ASSIGN_OR_RETURN(se::blas::DataType dtype, AsBlasDataType(layout.dtype)); - - auto order = (layout.order == MatrixLayout::Order::kColumnMajor) - ? se::cuda::BlasLt::MatrixLayout::Order::kColumnMajor - : se::cuda::BlasLt::MatrixLayout::Order::kRowMajor; - - return se::cuda::BlasLt::MatrixLayout::Create( - dtype, layout.num_rows, layout.num_cols, order, layout.batch_size, - layout.leading_dim_stride, layout.batch_stride); -} - -template -struct CudaToNativeT; - -#if CUDA_VERSION >= 11080 -template <> -struct CudaToNativeT { - using type = tsl::float8_e4m3fn; -}; -template <> -struct CudaToNativeT { - using type = tsl::float8_e5m2; -}; -#endif - -template <> -struct CudaToNativeT { - using type = Eigen::bfloat16; -}; -template <> -struct CudaToNativeT { - using type = Eigen::half; -}; -template <> -struct CudaToNativeT { - using type = float; -}; -template <> -struct CudaToNativeT { - using type = double; -}; -template <> -struct CudaToNativeT { - using type = complex64; -}; -template <> -struct CudaToNativeT { - using type = complex128; -}; - -} // namespace - -namespace cublas_lt { - -StatusOr AsBlasLtEpilogue( +StatusOr AsBlasLtEpilogue( mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue) { switch (epilogue) { case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Default: - return se::cuda::BlasLt::Epilogue::kDefault; + return BlasLt::Epilogue::kDefault; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Relu: - return se::cuda::BlasLt::Epilogue::kReLU; + return BlasLt::Epilogue::kReLU; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Gelu: - return se::cuda::BlasLt::Epilogue::kGELU; + return BlasLt::Epilogue::kGELU; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::GeluAux: - return se::cuda::BlasLt::Epilogue::kGELUWithAux; + return BlasLt::Epilogue::kGELUWithAux; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Bias: - return se::cuda::BlasLt::Epilogue::kBias; + return BlasLt::Epilogue::kBias; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasRelu: - return se::cuda::BlasLt::Epilogue::kBiasThenReLU; + return BlasLt::Epilogue::kBiasThenReLU; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGelu: - return se::cuda::BlasLt::Epilogue::kBiasThenGELU; + return BlasLt::Epilogue::kBiasThenGELU; case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGeluAux: - return se::cuda::BlasLt::Epilogue::kBiasThenGELUWithAux; + return BlasLt::Epilogue::kBiasThenGELUWithAux; } return InternalError("unexpected epilogue value"); } -/*static*/ StatusOr MatmulPlan::From( - const GemmConfig& config, se::cuda::BlasLt::Epilogue epilogue) { - MatrixLayout lhs_layout = config.lhs_layout; - MatrixLayout rhs_layout = config.rhs_layout; - MatrixLayout output_layout = config.output_layout; - MatrixLayout c_layout = config.c_layout; - - // cublasLt matmul requires batch sizes to be equal. If only one operand has a - // batch, the other will be broadcast (as its batch_stride == 0). - size_t batch_size = std::max(lhs_layout.batch_size, rhs_layout.batch_size); - lhs_layout.batch_size = batch_size; - rhs_layout.batch_size = batch_size; - - bool must_swap_operands = - MakeOutputColumnMajor(lhs_layout, rhs_layout, c_layout, output_layout); - - // Do not transopse either input. Note the cuBLASLt documentation somewhat - // incorrectly claims "A must be transposed and B non-transposed" when A and B - // are FP8 (https://docs.nvidia.com/cuda/cublas/#cublasltmatmul). In reality, - // this is only true if A and B are column-major. If A is row-major, A must - // *not* be transposed, and if B is row-major, B must be transposed. We never - // transpose A or B, and expect the caller to ensure A is row-major and B is - // column when A and B are FP8. - const se::blas::Transpose trans_a = se::blas::Transpose::kNoTranspose; - const se::blas::Transpose trans_b = se::blas::Transpose::kNoTranspose; - if (primitive_util::IsF8Type(lhs_layout.dtype) && - lhs_layout.order == MatrixLayout::Order::kColumnMajor) { - return InternalError("The F8 LHS must be column-major"); - } - if (primitive_util::IsF8Type(rhs_layout.dtype) && - rhs_layout.order == MatrixLayout::Order::kRowMajor) { - return InternalError("The F8 RHS must be row-major"); - } +} // namespace gpublas_lt - TF_ASSIGN_OR_RETURN(se::blas::DataType output_dtype, - AsBlasDataType(output_layout.dtype)); - TF_ASSIGN_OR_RETURN( - se::blas::ComputationType computation_type, - GetBlasComputationType(lhs_layout.dtype, output_layout.dtype, - config.compute_precision)); +StatusOr IsMatrixMultiplicationTooSmallForRewriting( + const HloInstruction& dot, int64_t threshold) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); - TF_ASSIGN_OR_RETURN( - se::cuda::BlasLt::MatmulDesc op_desc, - se::cuda::BlasLt::MatmulDesc::Create( - computation_type, GetScaleType(output_dtype, computation_type), - trans_a, trans_b, epilogue)); - - TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout a_desc, - AsBlasLtMatrixLayout(lhs_layout)); - TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout b_desc, - AsBlasLtMatrixLayout(rhs_layout)); - TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout c_desc, - AsBlasLtMatrixLayout(c_layout)); - TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout d_desc, - AsBlasLtMatrixLayout(output_layout)); - - return MatmulPlan{ - se::cuda::BlasLt::MatmulPlan{std::move(op_desc), std::move(a_desc), - std::move(b_desc), std::move(c_desc), - std::move(d_desc)}, - config.alpha, config.beta, must_swap_operands}; -} + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dot_dims = dot.dot_dimension_numbers(); -template -Status MatmulPlan::DoMatmul( - se::Stream* stream, se::DeviceMemoryBase a_buffer, - se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer, - se::DeviceMemoryBase d_buffer, se::DeviceMemoryBase bias_buffer, - se::DeviceMemoryBase aux_buffer, se::DeviceMemoryBase a_scale_buffer, - se::DeviceMemoryBase b_scale_buffer, se::DeviceMemoryBase c_scale_buffer, - se::DeviceMemoryBase d_scale_buffer, se::DeviceMemoryBase d_amax_buffer, - const se::cuda::BlasLt::MatmulAlgorithm& algorithm, - se::ScratchAllocator& scratch_allocator, - se::blas::ProfileResult* profile_result) const { - se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream); - TF_RET_CHECK(blas_lt != nullptr); - - Scale alpha; - if constexpr (std::is_same_v || - std::is_same_v) { - alpha = static_cast(alpha_); - } else { - alpha = static_cast(alpha_.real()); + int64_t contracting_size = 1; + for (int64_t dim : dot_dims.lhs_contracting_dimensions()) { + contracting_size *= lhs_shape.dimensions(dim); } - Scale beta = static_cast(beta_); - - se::DeviceMemory output(d_buffer); - return blas_lt->DoMatmul( - stream, plan_, se::HostOrDeviceScalar(alpha), - se::DeviceMemory(a_buffer), se::DeviceMemory(b_buffer), - se::HostOrDeviceScalar(beta), se::DeviceMemory(c_buffer), - output, algorithm, scratch_allocator, se::DeviceMemory(bias_buffer), - aux_buffer, se::DeviceMemory(a_scale_buffer), - se::DeviceMemory(b_scale_buffer), - se::DeviceMemory(c_scale_buffer), - se::DeviceMemory(d_scale_buffer), - se::DeviceMemory(d_amax_buffer), profile_result); -} - -Status MatmulPlan::ExecuteOnStream( - se::Stream* stream, se::DeviceMemoryBase a_buffer, - se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer, - se::DeviceMemoryBase d_buffer, se::DeviceMemoryBase bias_buffer, - se::DeviceMemoryBase aux_buffer, se::DeviceMemoryBase a_scale_buffer, - se::DeviceMemoryBase b_scale_buffer, se::DeviceMemoryBase c_scale_buffer, - se::DeviceMemoryBase d_scale_buffer, se::DeviceMemoryBase d_amax_buffer, - const se::cuda::BlasLt::MatmulAlgorithm& algorithm, - se::ScratchAllocator& scratch_allocator, - se::blas::ProfileResult* profile_result) const { - if (must_swap_operands_) { - std::swap(a_buffer, b_buffer); + TF_ASSIGN_OR_RETURN( + std::vector lhs_non_contracting_dims, + GetNonContractingDims(lhs_shape, dot_dims.lhs_batch_dimensions(), + dot_dims.lhs_contracting_dimensions())); + int64_t lhs_non_contracting_size = 1; + for (int64_t dim : lhs_non_contracting_dims) { + lhs_non_contracting_size *= lhs_shape.dimensions(dim); } - std::tuple - operand_types{plan_.a_desc.type(), plan_.b_desc.type(), - plan_.c_desc.type(), plan_.d_desc.type()}; - -#define TYPED_MATMUL(SCALENTYPE, ATYPE, BTYPE, CTYPE, DTYPE) \ - if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE, DTYPE)) { \ - return DoMatmul::type, \ - CudaToNativeT::type, CudaToNativeT::type, \ - CudaToNativeT::type>( \ - stream, a_buffer, b_buffer, c_buffer, d_buffer, bias_buffer, \ - aux_buffer, a_scale_buffer, b_scale_buffer, c_scale_buffer, \ - d_scale_buffer, d_amax_buffer, algorithm, scratch_allocator, \ - profile_result); \ + TF_ASSIGN_OR_RETURN( + std::vector rhs_non_contracting_dims, + GetNonContractingDims(rhs_shape, dot_dims.rhs_batch_dimensions(), + dot_dims.rhs_contracting_dimensions())); + int64_t rhs_non_contracting_size = 1; + for (int64_t dim : rhs_non_contracting_dims) { + rhs_non_contracting_size *= rhs_shape.dimensions(dim); } -#if CUDA_VERSION >= 11080 - // FP8 compatible type combinations (see cuBLASLt documentation): - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16BF, CUDA_R_16BF) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16BF, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16F, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16F, CUDA_R_16F) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_32F, CUDA_R_32F) - - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16BF, CUDA_R_16BF) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16BF, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16BF, - CUDA_R_8F_E5M2) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16F, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16F, - CUDA_R_8F_E5M2) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_16F, CUDA_R_16F) - TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E5M2, CUDA_R_32F, CUDA_R_32F) - - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16BF, CUDA_R_16BF) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16BF, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16BF, - CUDA_R_8F_E5M2) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16F, - CUDA_R_8F_E4M3) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16F, - CUDA_R_8F_E5M2) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16F, CUDA_R_16F) - TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_32F, CUDA_R_32F) -#endif - - // Other data types: - TYPED_MATMUL(float, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) - TYPED_MATMUL(float, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF) - TYPED_MATMUL(float, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F) - TYPED_MATMUL(double, CUDA_R_64F, CUDA_R_64F, CUDA_R_64F, CUDA_R_64F) - TYPED_MATMUL(complex64, CUDA_C_32F, CUDA_C_32F, CUDA_C_32F, CUDA_C_32F) - TYPED_MATMUL(complex128, CUDA_C_64F, CUDA_C_64F, CUDA_C_64F, CUDA_C_64F) - -#undef TYPED_MATMUL - - return InternalError("Unexpected dtype"); + return (rhs_non_contracting_size + lhs_non_contracting_size) * + contracting_size < + threshold; } -StatusOr> -MatmulPlan::GetAlgorithms(se::Stream* stream) const { - se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream); - TF_RET_CHECK(blas_lt != nullptr); - TF_ASSIGN_OR_RETURN(auto preference, - se::cuda::BlasLt::MatmulPreference::Create( - /*max_workspace_size=*/1ll << 32)); // 4GB - return blas_lt->GetMatmulAlgorithms(plan_, preference); -} +bool IsDotSupportedByClassicalEmitters(const HloInstruction& dot) { + // if (!algorithm_util::IsSupportedByElementalIrEmitter( + // dot.precision_config().algorithm())) { + // return false; + // } -} // namespace cublas_lt - -#endif // GOOGLE_CUDA + // Let us be conservative and only throw float dots at the emitters. + switch (dot.shape().element_type()) { + case F16: + case F32: + case BF16: + return true; + default: + return false; + } +} } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.h b/tensorflow/compiler/xla/service/gpu/matmul_utils.h index 6f8662cd565809..dde2d029c40629 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.h +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_ +#include #include #include +#include +#include #include #include @@ -30,67 +33,91 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/stream_executor/blas.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#if GOOGLE_CUDA -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" -#include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h" -#endif // GOOGLE_CUDA - namespace xla { namespace gpu { +// Ordered non-contracting dimensions for a dot instruction operand. StatusOr> GetNonContractingDims( const Shape& shape, absl::Span batch_dims, absl::Span contracting_dims); +// Batch dimensions of an operand of a dot instruction. +// Just an unified accessor to lhs_batch_dimensions and rhs_batch_dimensions. +const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( + const HloInstruction& dot, int operand_number); + +// Index of the only contracting dimension of dot instruction operand. +StatusOr ContractingDimensionIndex(const HloInstruction& dot, + int operand_number); + +// Index of the only non-contracting dimension of dot instruction operand. +StatusOr NonContractingDimensionIndex(const HloInstruction& dot, + int operand_number); + // Normalize shape to (batch, rows, columns) logical dimensions. -StatusOr GetBatchRowColumnShape(const Shape& shape, - absl::Span batch_dims, - absl::Span row_dims, - absl::Span col_dims); +StatusOr GetBatchRowColumnShape( + const Shape& shape, absl::Span batch_dims, + absl::Span row_dims, absl::Span col_dims); -struct MatrixLayout { - enum class Order { - kRowMajor, // Elements in the same row are contiguous in memory. - kColumnMajor, // Elements in the same column are contiguous in memory. - }; +// GPU folding rule for the `TransposeFolding` pass. +StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, + int64_t operand_idx); + +// Returns true if the sum of the sizes of the unbatched operand matrices +// for the dot is smaller than the given threshold. +StatusOr IsMatrixMultiplicationTooSmallForRewriting( + const HloInstruction& dot, int64_t threshold); + +// Returns true if the backend can lower the dot. Currently the classical +// emitters cannot handle some dots, e.g., i8[] x i8[] -> i32[] dots, +// so we need to always use cuBLAS or Triton for those. +bool IsDotSupportedByClassicalEmitters(const HloInstruction& dot); + +// extending plain MatrixLayout struct with creator functions +struct MatrixLayout : public se::gpu::MatrixLayout { + + MatrixLayout(se::blas::DataType dtype_, int64_t num_rows_, int64_t num_cols_, + Order order_, int64_t batch_size_ = 1, + absl::optional leading_dim_stride_ = {}, + absl::optional batch_stride_ = {}, + absl::optional transpose_ = {}) : + se::gpu::MatrixLayout(dtype_, num_rows_, num_cols_, + order_, batch_size_, leading_dim_stride_, + batch_stride_, transpose_) {} // Returns the matrix layout for a logical shape (batch, rows, columns). static StatusOr For(const Shape& shape); // Returns the matrix layout with the given batch, row, col dimensions. static StatusOr For(const Shape& shape, - absl::Span batch_dims, - absl::Span row_dims, - absl::Span col_dims); + absl::Span batch_dims, + absl::Span row_dims, + absl::Span col_dims); // Returns the matrix layout for the output. static StatusOr For(const Shape& shape, - size_t lhs_num_batch_dims, - size_t lhs_num_row_dims, - size_t rhs_num_batch_dims, - size_t rhs_num_col_dims); - - void Transpose(); - - PrimitiveType dtype; - // `num_rows` / `num_cols` are for the "logical" matrix shape: - // i.e. the contracting dim has size `num_cols` for LHS operands and - // `num_rows` for RHS operands. - int64_t num_rows; - int64_t num_cols; - Order order; - int64_t leading_dim_stride; - int64_t batch_size; - int64_t batch_stride; // `batch_stride` is set to `0` when `batch_size == 1`. + size_t lhs_num_batch_dims, + size_t lhs_num_row_dims, + size_t rhs_num_batch_dims, + size_t rhs_num_col_dims); }; -// GPU folding rule for the `TransposeFolding` pass. -StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, - int64_t operand_idx); +struct GemmConfig : public se::gpu::GemmConfig { + // For legacy Gemm operations XLA:GPU allocates its own workspace and passes + // it to all BLAS API calls. + // + // Size of the workspace based on NVIDIA recommendation: + // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace + static constexpr int64_t kHopperWorkspace = 32 * 1024 * 1024; // 32 MiB + static constexpr int64_t kDefaultWorkspace = 4 * 1024 * 1024; // 4 MiB + + explicit GemmConfig(const se::gpu::GemmConfig& cfg) : + se::gpu::GemmConfig(cfg) { } -struct GemmConfig { static StatusOr For(const HloInstruction* gemm); static StatusOr For(mlir::lmhlo_gpu::GEMMOp op); + static StatusOr For(mlir::lmhlo_gpu::CublasLtMatmulOp op); static StatusOr For( const Shape& lhs_shape, absl::Span lhs_batch_dims, @@ -98,157 +125,45 @@ struct GemmConfig { absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, - std::optional algorithm, int64_t compute_precision, - bool grad_x, bool grad_y); - - // As above with additional `c_shape` parameter. - static StatusOr For( - const Shape& lhs_shape, absl::Span lhs_batch_dims, - absl::Span lhs_contracting_dims, const Shape& rhs_shape, - absl::Span rhs_batch_dims, - absl::Span rhs_contracting_dims, const Shape& c_shape, - const Shape& output_shape, double alpha_real, double alpha_imag, - double beta, std::optional algorithm, int64_t compute_precision, - bool grad_x, bool grad_y); - - MatrixLayout lhs_layout; - MatrixLayout rhs_layout; - MatrixLayout c_layout; - MatrixLayout output_layout; - complex128 alpha; - double beta; - std::optional algorithm; - int64_t compute_precision; - bool grad_x, grad_y; + int64_t algorithm, int64_t compute_precision, + se::gpu::BlasLt::Epilogue epilogue); + + struct DescriptorsTuple { + se::gpu::MatrixDescriptor lhs; + se::gpu::MatrixDescriptor rhs; + se::gpu::OutputMatrixDescriptor output; + bool operands_swapped; + }; + StatusOr GetMatrixDescriptors( + se::DeviceMemoryBase lhs_buf, se::DeviceMemoryBase rhs_buf, + se::DeviceMemoryBase out_buf) const; }; -StatusOr GetBlasComputationType( - PrimitiveType lhs_dtype, PrimitiveType output_dtype, - int64_t compute_precision); - -namespace cublas_lt { - -// Returns the type for the alpha and beta scalars. -se::blas::DataType GetScaleType(se::blas::DataType c_type, - se::blas::ComputationType computation_type); - -} // namespace cublas_lt - // Run the given GEMM instruction `gemm` subject to the configuration // in `gemm_config` and the passed buffers. // // If `algorithm` is provided, it overrides the one specified in `config`. -Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, se::Stream* stream, - std::optional algorithm = std::nullopt, - se::blas::ProfileResult* profile_result = nullptr); - -namespace cublas_lt { - -StatusOr EpilogueAddsVectorBias(GemmBackendConfig_Epilogue epilogue); -StatusOr EpilogueHasAuxiliaryOutput(GemmBackendConfig_Epilogue epilogue); - -} // namespace cublas_lt - -StatusOr AsBlasDataType(PrimitiveType dtype); - -#if GOOGLE_CUDA - -namespace cublas_lt { - -StatusOr AsBlasLtEpilogue( +Status RunGemm( + const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, + se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase workspace_buffer, bool deterministic_ops, + se::Stream* stream, + std::optional algorithm = std::nullopt, + se::blas::ProfileResult* profile_result = nullptr); + +namespace gpublas_lt { + +StatusOr EpilogueAddsVectorBias( + GemmBackendConfig_Epilogue epilogue); +StatusOr EpilogueHasAuxiliaryOutput( + GemmBackendConfig_Epilogue epilogue); + +StatusOr AsBlasLtEpilogue( + GemmBackendConfig_Epilogue epilogue); +StatusOr AsBlasLtEpilogue( mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue); -class MatmulPlan { - public: - template ::value || - std::is_same::value>> - static StatusOr For(CublasLtMatmulMaybeF8Op op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().has_value()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.value()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } - } - } - - TF_ASSIGN_OR_RETURN( - GemmConfig config, - GemmConfig::For( - GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), - GetShape(op.getD()), op.getAlphaReal().convertToDouble(), - op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), - op.getAlgorithm(), compute_precision)); - - TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::Epilogue epilogue, - AsBlasLtEpilogue(op.getEpilogue())); - return From(config, epilogue); - } - - static StatusOr From(const GemmConfig& config, - se::cuda::BlasLt::Epilogue epilogue); - - Status ExecuteOnStream( - se::Stream* stream, se::DeviceMemoryBase a_buffer, - se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer, - se::DeviceMemoryBase d_buffer, - se::DeviceMemoryBase bias_buffer, // may be null - se::DeviceMemoryBase aux_buffer, // may be null - se::DeviceMemoryBase a_scale_buffer, se::DeviceMemoryBase b_scale_buffer, - se::DeviceMemoryBase c_scale_buffer, se::DeviceMemoryBase d_scale_buffer, - se::DeviceMemoryBase d_amax_buffer, - const se::cuda::BlasLt::MatmulAlgorithm& algorithm, - se::ScratchAllocator& scratch_allocator, - se::blas::ProfileResult* profile_result = nullptr) const; - - StatusOr> GetAlgorithms( - se::Stream* stream) const; - - private: - MatmulPlan(se::cuda::BlasLt::MatmulPlan plan, complex128 alpha, double beta, - bool must_swap_operands) - : plan_(std::move(plan)), - alpha_(alpha), - beta_(beta), - must_swap_operands_(must_swap_operands) {} - - template - Status DoMatmul(se::Stream* stream, se::DeviceMemoryBase a_buffer, - se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer, - se::DeviceMemoryBase d_buffer, - se::DeviceMemoryBase bias_buffer, // may be null - se::DeviceMemoryBase aux_buffer, // may be null - se::DeviceMemoryBase a_scale, se::DeviceMemoryBase b_scale, - se::DeviceMemoryBase c_scale, se::DeviceMemoryBase d_scale, - se::DeviceMemoryBase d_amax, - const se::cuda::BlasLt::MatmulAlgorithm& algorithm, - se::ScratchAllocator& scratch_allocator, - se::blas::ProfileResult* profile_result) const; - - se::cuda::BlasLt::MatmulPlan plan_; - complex128 alpha_; - double beta_; - bool must_swap_operands_; -}; - -} // namespace cublas_lt - -#endif // GOOGLE_CUDA +} // namespace gpublas_lt } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/runtime/BUILD b/tensorflow/compiler/xla/service/gpu/runtime/BUILD index b8b530d9a19526..665c48be52d93b 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/BUILD +++ b/tensorflow/compiler/xla/service/gpu/runtime/BUILD @@ -73,18 +73,17 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:gpu_asm_opts_util", "//tensorflow/compiler/xla/service/gpu:gpu_conv_algorithm_picker", "//tensorflow/compiler/xla/service/gpu:gpu_conv_runner", - "//tensorflow/compiler/xla/service/gpu:gpu_serializable_autotuner", + "//tensorflow/compiler/xla/service/gpu:autotuner_util", "//tensorflow/compiler/xla/service/gpu:non_atomically_upgradeable_rw_lock", "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", + "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:attribute_exporter", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", - ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", - ]), + ], ) cc_library( diff --git a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc index 06901e21768bbe..7dc9ca1d8ca83f 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_serializable_autotuner.h" #include "tensorflow/compiler/xla/service/gpu/non_atomically_upgradeable_rw_lock.h" #include "tensorflow/compiler/xla/service/gpu/runtime/support.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" @@ -394,8 +393,12 @@ static absl::Status ConvImpl( auto stream_exec = run_options->stream()->parent(); auto allocator = run_options->allocator(); - DeviceConfig device_config = {stream_exec, allocator}; - GpuConvAlgorithmPicker conv_algorithm_picker(device_config); + + CHECK(stream_exec != nullptr); + AutotuneConfig autotune_config{DeviceConfig{stream_exec, allocator}, + *debug_options}; + + GpuConvAlgorithmPicker conv_algorithm_picker(autotune_config); GpuConvConfig gpu_conv_config = conv.value()->config; auto autotune_result = diff --git a/tensorflow/compiler/xla/service/gpu/runtime/cublas_lt_matmul.cc b/tensorflow/compiler/xla/service/gpu/runtime/cublas_lt_matmul.cc index 3b6de30acf58d0..e9507c3130d2b9 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/cublas_lt_matmul.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/cublas_lt_matmul.cc @@ -72,7 +72,7 @@ void PopulateCublasLtMatmulAttrEncoding(CustomCallAttrEncodingSet& encoding) { return cublas_lt::AsBlasLtEpilogue(value).value(); }); } - +34f34f //===----------------------------------------------------------------------===// // cuBLASLt matmul custom call implementation. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc b/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc index 788ad2a376be44..bec4604f7d3d29 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc @@ -84,8 +84,9 @@ Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config, // we pass a non-null ProfileResult, DoGemmWithAlgorithm should // always return true, and the actual success-ness is returned in // ProfileResult::is_valid. - TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, stream, algorithm, - &profile_result)); + se::DeviceMemoryBase workspace{}; + TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, workspace, false, + stream, algorithm, &profile_result)); return std::move(profile_result); })); @@ -147,8 +148,10 @@ static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options, #endif } + se::DeviceMemoryBase workspace{}; Status executed = - RunGemm(*gemm_config, lhs_data, rhs_data, output_data, stream); + RunGemm(*gemm_config, lhs_data, rhs_data, output_data, workspace, false, + stream); if (!executed.ok()) return ToAbslStatus(executed); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/support.h b/tensorflow/compiler/xla/service/gpu/runtime/support.h index 57767c32adb208..edeb23e71cd188 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/support.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/support.h @@ -100,7 +100,8 @@ inline StatusOr GetGemmConfig( return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs), rhs_batch, rhs_contract, ToShape(out), alpha_real, alpha_imag, beta, algorithm, - se::blas::kDefaultComputePrecision, grad_x, grad_y); + se::blas::kDefaultComputePrecision, + se::gpu::BlasLt::Epilogue::kDefault); } // adds Dot Dimension Attribute encodings for calls to Gemm and cuBLASLt diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index bd0fc2cae1801e..550a48e5be1702 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -37,14 +37,14 @@ limitations under the License. namespace xla { namespace gpu { -namespace { - using se::dnn::DataLayout; using se::dnn::DataLayoutString; using se::dnn::FilterLayout; using se::dnn::FilterLayoutString; using tensorflow::AutotuneResult; +namespace { + // Returns the smallest integer >= 0 that's not in the given set of numbers. // // For example, FindMissingDnum({1, 0, 3, 4}) returns 2. @@ -94,7 +94,107 @@ StatusOr DataLayoutToXlaLayout( return LayoutUtil::MakeLayoutFromMajorToMinor(layout); } -} // anonymous namespace +std::vector KeepNonFailures( + absl::Span profile_results) { + // Filter out all failures except WRONG_RESULT, because false-positives are + // possible (e.g. perhaps the reference algorithm is the one that's + // incorrect!). Other failures can be detected with high accuracy. E.g. + // REDZONE_MODIFIED which is also quite severe. + std::vector filtered_results; + absl::c_copy_if(profile_results, std::back_inserter(filtered_results), + [](const AutotuneResult& r) { + return !r.has_failure() || + r.failure().kind() == AutotuneResult::WRONG_RESULT; + }); + return filtered_results; +} + +Status AllAlgorithmsFailedInternalError( + absl::optional instr_str, + absl::Span profile_results) { + std::ostringstream msg; + if (instr_str.has_value()) { + msg << "All algorithms tried for " << instr_str.value() + << " failed. Falling back to default algorithm. Per-algorithm " + "errors:"; + } else { + msg << "All algorithms failed. Falling back to the default algorithm. " + << "Per-algorithm errors:"; + } + for (const auto& result : profile_results) { + msg << "\n " << result.failure().msg(); + } + return Internal("%s", msg.str()); +} + +Status NoAlgorithmSuppliedInternalError( + absl::optional instr_str) { + std::ostringstream msg; + if (instr_str.has_value()) { + msg << "There are no algorithm candidates for computing: \n " + << instr_str.value() + << "\nThis likely means that the instruction shape is not supported by " + "the target GPU library."; + } else { + msg << "There are no algorithm candidates for computing the instruction.\n" + "This likely means that the instruction shape is not supported by " + "the target GPU library."; + } + return Internal("%s", msg.str()); +} + +void SortAutotuningResultsByRunTime(std::vector& results) { + absl::c_sort(results, + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return tsl::proto_utils::FromDurationProto(lhs.run_time()) < + tsl::proto_utils::FromDurationProto(rhs.run_time()); + }); +} + +absl::Span TopResultsWithinMeasurementError( + std::vector& results_sorted_by_runtime) { + // This value was picked by repeatedly running a few kernels that run for a + // short time and observing the run-time variance. A more rigorous analysis + // of the measurement error might yield a better error threshold. + constexpr absl::Duration kMeasurementError = absl::Microseconds(4); + + absl::Duration min_time = tsl::proto_utils::FromDurationProto( + results_sorted_by_runtime.front().run_time()); + absl::Duration limit_time = min_time + kMeasurementError; + + auto limit_time_it = absl::c_find_if( + results_sorted_by_runtime, [limit_time](const AutotuneResult& x) { + return tsl::proto_utils::FromDurationProto(x.run_time()) > limit_time; + }); + return absl::MakeSpan(&*results_sorted_by_runtime.begin(), &*limit_time_it); +} +} // anonymous namespace + +StatusOr PickBestResult( + absl::Span profile_results, + absl::optional instr_str) { + if (profile_results.empty()) { + return NoAlgorithmSuppliedInternalError(instr_str); + } + + std::vector filtered_results = + KeepNonFailures(profile_results); + + if (filtered_results.empty()) { + return AllAlgorithmsFailedInternalError(instr_str, profile_results); + } + + // Kernel run-time measurements within kMeasurementError are not precise. + // Consider the lowest measurements within the error margin as equivalent and + // within them prefer algorithms that use the least amount of scratch memory. + SortAutotuningResultsByRunTime(filtered_results); + auto top_within_error = TopResultsWithinMeasurementError(filtered_results); + return *absl::c_min_element(top_within_error, [](const AutotuneResult& lhs, + const AutotuneResult& rhs) { + return lhs.scratch_bytes() < rhs.scratch_bytes(); + }); +} + StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, @@ -374,7 +474,6 @@ Status ExecuteKernelOnStream(const se::KernelBase& kernel, se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel, *kernel_args); } - // Unimplemented for integers yet. template typename std::enable_if::value, @@ -395,38 +494,93 @@ static void InitializeTypedBuffer(se::Stream* stream, int64_t* rng_state) { // Accesses to static variables are not locked, since the caller is already // in a critical section. + + // Use a large prime number to fragment the accesses. + constexpr int host_buffer_size = 10069; + static std::vector* host_buffer = [] { // Use a large prime number to fragment the accesses. - auto* ret = new std::vector(10069); + auto* ret = new std::vector(host_buffer_size); // Default-seeded random numbers. std::mt19937 gen; for (auto& element : *ret) { + constexpr bool kIsIntegral = std::numeric_limits::is_integer; + constexpr bool kIsLowRange = + !kIsIntegral && std::numeric_limits::max_exponent <= + std::numeric_limits::max_exponent; // Only double gets random values in double. Other data types get random // values in float then cast them to the target data types. - using RandomFloatingPointType = - typename std::conditional::value || - std::is_same::value, - float, T>::type; - using RandomType = - typename std::conditional::value, float, - RandomFloatingPointType>::type; + using RandomType = typename std::conditional::value, + double, float>::type; // Scale down the values for fp16 to have less overflows. - auto upper_bound = - RandomType(std::is_same::value ? 0.1 : 1.0); + auto upper_bound = RandomType(kIsLowRange ? 0.1 : 1.0); auto rand_val = UniformDistribution(RandomType(0), upper_bound, &gen); // For bf16, float or double, it is between [0,1]. // For fp16, it ranges between [0, 0.1]. // For integer types, element is either 0 or 1 for less overflows // especially for int8_t. - element = T(std::is_integral::value ? rand_val + 0.5 : rand_val); + element = T(kIsIntegral ? rand_val + 0.5 : rand_val); } return ret; }(); + // The buffer of random numbers is treated as being circular, and the seed in + // *rng_state is the offset in host_buffer that is copied to the zeroth index + // on the device. For large buffers then repeatedly copying the data from the + // host is expensive, so we just copy it once and use a kernel to repeat the + // data as needed. +#ifdef GOOGLE_CUDA + CHECK_EQ(0, buffer.size() % sizeof(T)); + int64_t elements_to_fill = buffer.size() / sizeof(T); + int64_t host_index = *rng_state; + CHECK_LT(host_index, host_buffer_size); + *rng_state = (*rng_state + elements_to_fill) % host_buffer_size; + // Copy the last part of `host_buffer` to the start of `buf` on the device + int64_t first_size = + std::min(host_buffer_size - host_index, elements_to_fill); + stream->ThenMemcpy(&buffer, host_buffer->data() + host_index, + first_size * sizeof(T)); + elements_to_fill -= first_size; + if (elements_to_fill == 0) { + // Nothing more to do + return; + } + // Issue a second host->device copy to transfer the rest of host_buffer + int64_t second_size = std::min(host_index, elements_to_fill); + CHECK_LE(first_size + second_size, host_buffer_size); + // = buffer.GetByteSlice(first_size * sizeof(T), second_size * sizeof(T)); + se::DeviceMemoryBase mem( static_cast< uint8_t *>(buffer.opaque()) + + first_size * sizeof(T), second_size * sizeof(T)); + + stream->ThenMemcpy(&mem, host_buffer->data(), mem.size()); + elements_to_fill -= second_size; + if (elements_to_fill == 0) { + // Nothing more to do + return; + } + // Repeat the host_buffer_size elements at the start of `buf` to the end + CHECK_EQ(elements_to_fill, buffer.size() / sizeof(T) - host_buffer_size); + se::StreamExecutor* executor = stream->parent(); + auto kernel = + se::TypedKernelFactory::Create( + executor, "RepeatBufferKernel", repeat_buffer_kernel::kernel()); + if (!kernel.ok()) { + LOG(FATAL) << "Could not create RepeatBufferKernel: " << kernel.status(); + } + // Launch the kernel with at least host_buffer_bytes threads. Each thread + // will read one byte of `host_buffer` from the start of `buffer`, where the + // Memcpy call(s) above put it, and scatter it through the rest of `buffer`. + constexpr int64_t host_buffer_bytes = host_buffer_size * sizeof(T); + constexpr int threads_per_block = 256; + constexpr int blocks_per_grid = + (host_buffer_bytes + threads_per_block - 1) / threads_per_block; + TF_CHECK_OK(stream->ThenLaunch(se::ThreadDim(threads_per_block, 1, 1), + se::BlockDim(blocks_per_grid, 1, 1), *kernel, + buffer, host_buffer_bytes, + static_cast(buffer.size()))); +#else // GOOGLE_CUDA int64_t& host_index = *rng_state; - char* current_addr = static_cast(buffer.opaque()); - CHECK_EQ(0, buffer.size() % sizeof(T)); int64_t elements_left = buffer.size() / sizeof(T); while (elements_left > 0) { CHECK_LE(host_index, host_buffer->size()); @@ -437,15 +591,16 @@ static void InitializeTypedBuffer(se::Stream* stream, std::min(host_buffer->size() - host_index, elements_left); se::DeviceMemoryBase mem(current_addr, elements_copied * sizeof(T)); stream->ThenMemcpy(&mem, host_buffer->data() + host_index, - elements_copied * sizeof(T)); + elements_copied * sizeof(T)); current_addr += elements_copied * sizeof(T); elements_left -= elements_copied; host_index += elements_copied; } +#endif // GOOGLE_CUDA } void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, - int64_t* rng_state, se::DeviceMemoryBase buffer) { + int64_t* rng_state, se::DeviceMemoryBase buffer) { switch (buffer_type) { case xla::F16: return InitializeTypedBuffer(stream, buffer, rng_state); @@ -457,16 +612,12 @@ void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, case xla::F64: case xla::C128: return InitializeTypedBuffer(stream, buffer, rng_state); - case xla::PRED: - // Using S8 for PRED initialization, as vector has different - // semantics and cannot be used as a buffer. case xla::S8: return InitializeTypedBuffer(stream, buffer, rng_state); - case xla::S32: - return InitializeTypedBuffer(stream, buffer, rng_state); + case xla::U8: + return InitializeTypedBuffer(stream, buffer, rng_state); default: - LOG(FATAL) << "Unexpected type: " - << primitive_util::LowercasePrimitiveTypeName(buffer_type); + LOG(FATAL) << "Unexpected type: " << PrimitiveType_Name(buffer_type); } } @@ -521,50 +672,5 @@ bool RequireDeterminism(const HloModuleConfig& config) { config.debug_options().xla_gpu_deterministic_ops(); } -StatusOr PickBestResult( - absl::Span profile_results, - std::optional instr_str, - HloModuleConfig hlo_module_config) { - std::vector filtered_results; - - // For now, we ignore WRONG_RESULT failures because false-positives are - // possible (e.g. perhaps the reference algorithm is the one that's - // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're - // quite severe and can be detected with high accuracy. - absl::c_copy_if( - profile_results, std::back_inserter(filtered_results), - [](const AutotuneResult& r) { - return !(r.has_failure() && - r.failure().kind() != AutotuneResult::WRONG_RESULT); - }); - - if (filtered_results.empty()) { - std::ostringstream msg; - if (instr_str.has_value()) { - msg << "All algorithms tried for " << instr_str.value() - << " failed. Falling back to default algorithm. Per-algorithm " - "errors:"; - } else { - msg << "All algorithms failed. Falling back to the default algorithm. " - << "Per-algorithm errors:"; - } - for (const auto& result : profile_results) { - msg << "\n " << result.failure().msg(); - } - return InternalError("%s", msg.str()); - } - - auto selected_result = filtered_results.begin(); - if (!RequireDeterminism(hlo_module_config)) { - selected_result = absl::c_min_element( - filtered_results, - [](const AutotuneResult& lhs, const AutotuneResult& rhs) { - return tsl::proto_utils::FromDurationProto(lhs.run_time()) < - tsl::proto_utils::FromDurationProto(rhs.run_time()); - }); - } - return *selected_result; -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index a1b4d2219aecac..2b03cf8d9c12f1 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -107,8 +107,7 @@ StatusOr GetDNNDataTypeFromPrimitiveType(PrimitiveType type); // If deterministic output is requested, returns first (not failing) result. StatusOr PickBestResult( absl::Span profile_results, - std::optional instr_str, - HloModuleConfig hlo_module_config); + absl::optional instr_str); // Returns whether determinism is required. bool RequireDeterminism(const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 13de65ab562cdd..bf700c28b40c07 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -69,6 +69,21 @@ cc_library( ], ) +tf_cc_test( + name = "gpu_hlo_runner_test", + srcs = ["gpu_hlo_runner_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "element_wise_row_vectorization_test", srcs = ["element_wise_row_vectorization_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_hlo_runner_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_hlo_runner_test.cc new file mode 100644 index 00000000000000..61e1acd1b2268f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_hlo_runner_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal_comparison.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +namespace xla { +namespace gpu { + +template +std::vector MakePointerVector(std::vector& input_vec) { + std::vector output_pointers; + output_pointers.reserve(input_vec.size()); + for (auto& input : input_vec) { + output_pointers.push_back(&input); + } + return output_pointers; +} + + +class HloRunnerTest : public GpuCodegenTest {}; + +TEST_F(HloRunnerTest, RunSingle) { + + std::ifstream ifs("input.hlo"); + ASSERT_TRUE(ifs.good()); + + std::stringstream buffer; + buffer << ifs.rdbuf(); + + HloModuleConfig config = GetModuleConfigForTest(); +#if 1 + //config.set_num_partitions(8); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(buffer.str(), + config)); + + auto ref_module = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(auto exec, test_runner_.CreateExecutable(std::move(module), true)); + + VLOG(0) << "Creating fake args.."; + TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, xla::MakeFakeArguments(ref_module.get(), + true, /*pseudo-random*/ + false /* use large range*/)); + auto arg_ptrs = MakePointerVector(fake_arguments); + + auto& ref_runner = HloTestBase::reference_runner_; + TF_ASSERT_OK_AND_ASSIGN( + auto ref_exec, ref_runner.CreateExecutable(std::move(ref_module), true)); + + // TF_ASSERT_OK_AND_ASSIGN(auto truth, + // ReadLiteralFromProto("/tf/xla/expected.pb")); + // TF_ASSERT_OK_AND_ASSIGN(auto truth, + // ref_runner.ExecuteWithExecutable(ref_exec.get(), arg_ptrs, nullptr)); + // WriteLiteralToTempFile(truth, "expected"); + //VLOG(0) << "Got expected literal from file.. running test"; + + TF_ASSERT_OK_AND_ASSIGN( + auto test_res, test_runner_.ExecuteWithExecutable(exec.get(), arg_ptrs)); + + VLOG(0) << "Running reference exec.."; + TF_ASSERT_OK_AND_ASSIGN( + auto truth, ref_runner.ExecuteWithExecutable(ref_exec.get(), arg_ptrs)); + + ErrorSpec error_spec{1e-2, 1e-3}; + //ErrorSpec error_spec(1e-5 /*abs*/, 1e-5 /*rel*/); + ASSERT_EQ(literal_comparison::Near(/*expected=*/truth, + /*actual=*/test_res, + /*error=*/error_spec, + /*detailed_message=*/true, {}), OkStatus()); + + // EXPECT_TRUE(RunAndCompare(std::move(module), + // // absl::Span< xla::Literal * const>(arg_ptrs.data(), arg_ptrs.size()), error_spec)); +#else + int NumReplicas = 8, NumParts = 1; + config.set_replica_count(NumReplicas); + config.set_num_partitions(NumParts); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(buffer.str(), config)); + DeviceAssignment assn(/*replica_count=*/NumReplicas, + /*computation_count=*/NumParts); + for (int64_t i = 0, k = 0; i < NumReplicas; i++) + for (int64_t j = 0; j < NumParts; j++) { + assn(i, j) = k++; + } + + auto fake_arguments = xla::MakeFakeArguments( + module.get(), + true, /*pseudo-random*/ + false /* use large range*/).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(auto exec, + test_runner_.CreateExecutable(std::move(module), true)); + + for(int i = 0; i < 10; i++) { + VLOG(0) << "Running iteration #" << i; + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + HloTestBase::ExecuteReplicated( + [&](int64_t){ return exec.get(); }, + [&fake_arguments](int64_t replica_id) + { return fake_arguments.size(); }, + [&fake_arguments](int64_t replica_id, int64_t idx) + { return &fake_arguments[idx]; }, + NumReplicas, false /*run hlo*/, &assn)); + ASSERT_EQ(results.size(), NumReplicas); + } +#endif +} + +} // namespace gpu +} // namespace xla + \ No newline at end of file diff --git a/tensorflow/compiler/xla/stream_executor/BUILD b/tensorflow/compiler/xla/stream_executor/BUILD index 0425da4aea423a..34f7034934856a 100644 --- a/tensorflow/compiler/xla/stream_executor/BUILD +++ b/tensorflow/compiler/xla/stream_executor/BUILD @@ -450,6 +450,7 @@ tsl_gpu_library( ":temporary_memory_manager", ":timer", "//tensorflow/compiler/xla/stream_executor/platform", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt_gemm_runner", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", diff --git a/tensorflow/compiler/xla/stream_executor/blas.cc b/tensorflow/compiler/xla/stream_executor/blas.cc index 5eeb0fe38c0ff2..4a4f0af1ab2df3 100644 --- a/tensorflow/compiler/xla/stream_executor/blas.cc +++ b/tensorflow/compiler/xla/stream_executor/blas.cc @@ -22,6 +22,27 @@ limitations under the License. namespace stream_executor { namespace blas { +// TODO(ezhulenev): We need a scoped thread local map-like container to make +// sure that we can have multiple BlasSupport instances that do not overwrite +// each others workspaces. For not it's ok as we know that this can't happen. +static thread_local DeviceMemoryBase* workspace_thread_local = nullptr; + +BlasSupport::ScopedWorkspace::ScopedWorkspace(BlasSupport* blas, + DeviceMemoryBase* workspace) + : blas_(blas) { + blas->SetWorkspace(workspace); +} + +BlasSupport::ScopedWorkspace::~ScopedWorkspace() { blas_->ResetWorkspace(); } + +DeviceMemoryBase* BlasSupport::GetWorkspace() { return workspace_thread_local; } + +void BlasSupport::SetWorkspace(DeviceMemoryBase* workspace) { + workspace_thread_local = workspace; +} + +void BlasSupport::ResetWorkspace() { workspace_thread_local = nullptr; } + std::string TransposeString(Transpose t) { switch (t) { case Transpose::kNoTranspose: diff --git a/tensorflow/compiler/xla/stream_executor/blas.h b/tensorflow/compiler/xla/stream_executor/blas.h index b5e025acf6eac6..fec02e0ab7511a 100644 --- a/tensorflow/compiler/xla/stream_executor/blas.h +++ b/tensorflow/compiler/xla/stream_executor/blas.h @@ -56,6 +56,10 @@ struct half; namespace stream_executor { +namespace gpu { +struct BlasLt; +} // namespace gpu + class Stream; class ScratchAllocator; @@ -208,6 +212,7 @@ constexpr ComputePrecision kDefaultComputePrecision = 0; class BlasSupport { public: virtual ~BlasSupport() {} + virtual gpu::BlasLt *GetBlasLt() = 0; // Performs a BLAS y <- ax+y operation. virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, @@ -494,11 +499,41 @@ class BlasSupport { int ldb, int batch_count) = 0; virtual tsl::Status GetVersion(std::string *version) = 0; + // TODO(ezhulenev): We should never pass ScratchAllocator to any of the APIs + // in this file, because it makes them incompatible with command buffers (CUDA + // graphs). We should pass workspace memory explicitly to all APIs. However + // this is a giant change, so currently we work around it by setting a thread + // local workspace and rely on `ScopedBlasWorkspace` RAII helper to reset it. + // + // APIs that get ScratchAllocator ignore this workspace, and continue + // allocating scratch memory on demand. + class ScopedWorkspace { + public: + ScopedWorkspace(BlasSupport *blas, DeviceMemoryBase *workspace); + ~ScopedWorkspace(); + + private: + BlasSupport *blas_; + }; protected: + DeviceMemoryBase *GetWorkspace(); BlasSupport() {} private: + // Workspace memory pointer is thread local, once it is set all Blas + // operations issued from a caller thread might use it if it has large enough + // size. It's a user responsibility to make sure that workspace will outlive + // all issued BLAS operations. + // + // TODO(ezhulenev): This is a giant footgun! We have to remove it and use + // explicit workspace memory argument for all BLAS operations. + void SetWorkspace(DeviceMemoryBase *workspace); + + // Resets user-defined workspace memory, so that Blas operations can use their + // own memory pool for allocating workspace. + void ResetWorkspace(); + SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport); }; diff --git a/tensorflow/compiler/xla/stream_executor/gpu/BUILD b/tensorflow/compiler/xla/stream_executor/gpu/BUILD index 8843902db193ec..96f560cc1181be 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/gpu/BUILD @@ -1,13 +1,16 @@ # Description: # GPU-platform specific StreamExecutor support code. - +load( + "//tensorflow:tensorflow.bzl", + "tf_gpu_kernel_library", +) load( "//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "if_gpu_is_configured", ) load( "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_is_configured", + "if_rocm_is_configured", "rocm_copts" ) load( "//tensorflow/tsl:tsl.bzl", @@ -62,6 +65,39 @@ cc_library( ]), ) +cc_library( + name = "gpu_blas_lt", + srcs = if_gpu_is_configured(["gpu_blas_lt.cc"]), + hdrs = if_gpu_is_configured(["gpu_blas_lt.h"]), + deps = if_gpu_is_configured([ + "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:shape_util", + #"//tensorflow/core/platform:env", + "//tensorflow/tsl/util:env_var", + "@com_google_absl//absl/types:any", + ]), +) + + +cc_library( + name = "gpu_blas_lt_gemm_runner", + srcs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.cc"]), + hdrs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.h"]), + deps = if_gpu_is_configured([ + "//tensorflow/core/protobuf:autotuning_proto_cc", + "//tensorflow/compiler/xla:autotune_results_proto_cc", + "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/stream_executor:scratch_allocator", + "//tensorflow/compiler/xla/service/gpu:autotuner_util", + "//tensorflow/compiler/xla:debug_options_flags", + ":gpu_blas_lt", + ]), +) + cc_library( name = "gpu_diagnostics_header", hdrs = if_gpu_is_configured(["gpu_diagnostics.h"]), @@ -174,7 +210,7 @@ tsl_gpu_library( "//tensorflow/compiler/tf2xla:__subpackages__", "//tensorflow/compiler/xla:__subpackages__", "//tensorflow/core/common_runtime/gpu:__subpackages__", - "//tensorflow/stream_executor:__subpackages__", + "//tensorflow/compiler/xla/stream_executor:__subpackages__", ]), deps = [ "//tensorflow/compiler/xla/stream_executor:multi_platform_manager", @@ -362,33 +398,65 @@ cc_library( ]) + ["//tensorflow/tsl/platform:statusor"], ) +# cc_library( +# name = "redzone_allocator", +# srcs = if_gpu_is_configured(["redzone_allocator.cc"]), +# hdrs = if_gpu_is_configured(["redzone_allocator.h"]), +# copts = tsl_copts(), +# visibility = set_external_visibility([ +# "//tensorflow/compiler/xla/service/gpu:__subpackages__", +# "//tensorflow/compiler/xla/stream_executor:__subpackages__", +# "//tensorflow/core/kernels:__subpackages__", +# ]), +# deps = if_gpu_is_configured([ +# ":asm_compiler", +# ":gpu_asm_opts", +# "@com_google_absl//absl/base", +# "@com_google_absl//absl/container:fixed_array", +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/strings:str_format", +# "@com_google_absl//absl/types:optional", +# "//tensorflow/tsl/lib/math:math_util", +# "//tensorflow/tsl/platform:errors", +# "//tensorflow/tsl/platform:logging", +# "//tensorflow/tsl/framework:allocator", +# "//tensorflow/compiler/xla/stream_executor:device_memory", +# "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", +# "//tensorflow/compiler/xla/stream_executor:scratch_allocator", +# "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", +# "//tensorflow/tsl/platform:status", +# ]), +# ) + +tf_gpu_kernel_library( + name = "redzone_allocator_kernel", + hdrs = if_gpu_is_configured(["redzone_allocator_kernel.h"]), + srcs = if_cuda_is_configured(["redzone_allocator_kernel_cuda.cc"]) + + if_rocm_is_configured(["redzone_allocator_kernel_rocm.cu.cc"]), + deps = [":gpu_asm_opts"], +) + cc_library( name = "redzone_allocator", - srcs = if_gpu_is_configured(["redzone_allocator.cc"]), hdrs = if_gpu_is_configured(["redzone_allocator.h"]), - copts = tsl_copts(), - visibility = set_external_visibility([ - "//tensorflow/compiler/xla/service/gpu:__subpackages__", - "//tensorflow/compiler/xla/stream_executor:__subpackages__", - "//tensorflow/core/kernels:__subpackages__", + srcs = if_gpu_is_configured(["redzone_allocator.cc"]), + copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "-DTENSORFLOW_USE_ROCM=1", ]), - deps = if_gpu_is_configured([ - ":asm_compiler", - ":gpu_asm_opts", - "@com_google_absl//absl/base", + visibility = ["//visibility:public"], + deps = [ "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", - "//tensorflow/tsl/lib/math:math_util", - "//tensorflow/tsl/platform:errors", - "//tensorflow/tsl/platform:logging", - "//tensorflow/tsl/framework:allocator", + "@com_google_absl//absl/strings", + ":redzone_allocator_kernel", + ":gpu_asm_opts", "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", "//tensorflow/compiler/xla/stream_executor:scratch_allocator", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", - "//tensorflow/tsl/platform:status", + ] + if_cuda_is_configured([ + "//tensorflow/stream_executor/cuda:ptxas_utils", ]), ) diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc new file mode 100644 index 00000000000000..749ffc2f35a82a --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -0,0 +1,285 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/tsl/util/env_var.h" + +namespace stream_executor { + +namespace gpu { + +using blas::ComputationType; +using blas::DataType; +using xla::PrimitiveType; + +bool GpuBlasLtEnabled() { + static std::atomic_bool result{[] { + bool value = false; + tsl::ReadBoolFromEnvVar("TF_ENABLE_GPU_BLASLT", + /*default_value=*/false, &value); + return value; + }()}; + return result; +} + +namespace { + +bool TF32_Enabled() { + static std::atomic_bool result{[] { + bool value = false; + (void)tsl::ReadBoolFromEnvVar("ROCM_XF32", + /*default_value=*/false, &value); + return value; + }()}; + return result; +} + +bool Fast_16F_Enabled() { + static std::atomic_bool result{[] { + bool value = false; + (void)tsl::ReadBoolFromEnvVar("ROCM_FAST_16F", + /*default_value=*/false, &value); + return value; + }()}; + return result; +} + +} // namespace + +xla::StatusOr AsBlasDataType(PrimitiveType dtype) { + switch (dtype) { + case PrimitiveType::S8: + return DataType::kInt8; + case PrimitiveType::F16: + return DataType::kHalf; + case PrimitiveType::BF16: + return DataType::kBF16; + case PrimitiveType::F32: + return DataType::kFloat; + case PrimitiveType::S32: + return DataType::kInt32; + case PrimitiveType::F64: + return DataType::kDouble; + case PrimitiveType::C64: + return DataType::kComplexFloat; + case PrimitiveType::C128: + return DataType::kComplexDouble; + default: + return xla::InternalError( + "AsBlasDataType: unsupported type: %s", + xla::primitive_util::LowercasePrimitiveTypeName(dtype)); + } +} + +xla::StatusOr GetBlasComputationType( + DataType lhs_dtype, DataType output_dtype, int64_t /*compute_precision*/) { + + auto f16_comp = Fast_16F_Enabled() ? + ComputationType::kF16AsF32 : ComputationType::kF32, + bf16_comp = Fast_16F_Enabled() ? + ComputationType::kBF16AsF32 : ComputationType::kF32; + + switch (output_dtype) { + case DataType::kHalf: // fall-through + return f16_comp; + case DataType::kBF16: + return bf16_comp; + case DataType::kFloat: // fall-through + if (lhs_dtype == DataType::kHalf) return f16_comp; + if (lhs_dtype == DataType::kBF16) return bf16_comp; + return TF32_Enabled() ? ComputationType::kTF32AsF32 + : ComputationType::kF32; + case DataType::kComplexFloat: + return ComputationType::kF32; + case DataType::kDouble: // fall-through + case DataType::kComplexDouble: + return ComputationType::kF64; + case DataType::kInt32: + return ComputationType::kI32; + default: + return xla::InternalError("GetBlasComputationType: unsupported type"); + } +} + +MatrixLayout::MatrixLayout(blas::DataType dtype_, int64_t num_rows_, + int64_t num_cols_, MatrixLayout::Order order_, + int64_t batch_size_, + absl::optional leading_dim_stride_, + absl::optional batch_stride_, + absl::optional transpose_) + : dtype(dtype_), + num_rows(num_rows_), + num_cols(num_cols_), + order(order_), + batch_size(batch_size_) { + if (!leading_dim_stride_) { + leading_dim_stride = order == Order::kRowMajor ? num_cols : num_rows; + } else { + leading_dim_stride = *leading_dim_stride_; + } + if (!batch_stride_) { + batch_stride = (batch_size > 1) ? num_rows * num_cols : 0; + } else { + batch_stride = *batch_stride_; + } + transpose = transpose_ ? *transpose_ : blas::Transpose::kNoTranspose; +} + +void MatrixLayout::Transpose() { + std::swap(num_rows, num_cols); + order = (order == Order::kRowMajor) ? Order::kColumnMajor : Order::kRowMajor; +} + +// BLAS GeMM's output is column-major. If we require row-major, use identity: +// C^T = (A @ B)^T = B^T @ A^T. +bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, + MatrixLayout& output, MatrixLayout* c) { + bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor; + if (swap_operands) { + std::swap(lhs, rhs); + rhs.Transpose(); + // prevent layouts from being swapped two times if they are equal + if (&lhs != &rhs) { + lhs.Transpose(); + } + if (c != nullptr && c != &output) { + c->Transpose(); + } + output.Transpose(); + } + return swap_operands; +} + +/*static*/ auto BlasLt::GetMatmulPlan(const Stream* stream, + const GemmConfig& cfg) + -> xla::StatusOr { + auto blas = Get(stream); + if (blas == nullptr) { + return xla::InternalError("BlasLt is unavailable"); + } + return blas->GetMatmulPlan(cfg); +} + +/* static */ auto BlasLt::CreateGroupedMatmulPlan(Stream* stream, + const GroupedGemmConfig& cfg) -> xla::StatusOr { + auto blas = Get(stream); + if (blas == nullptr) { + return xla::InternalError("BlasLt is unavailable"); + } + return blas->GetGroupedMatmulPlan(stream, cfg); +} + +/*static*/ BlasLt* BlasLt::Get(const Stream* stream) { + auto blas = stream->parent()->AsBlas(); + return (blas != nullptr ? blas->GetBlasLt() : nullptr); +} + +DataType GetScaleType(DataType c_type, ComputationType compute_type) { + if (compute_type == ComputationType::kF32 && + c_type != DataType::kComplexFloat) { + return DataType::kFloat; + } + if (compute_type == ComputationType::kF16) return DataType::kFloat; + return c_type; +} + + +namespace { + +const std::vector TransposeNames = { + "N", // kNoTranspose + "T", // kTranspose + "C", // kConjugateTranspose +}; + +xla::StatusOr Transpose2String(blas::Transpose type) { + size_t idx = static_cast< size_t >(type); + if (idx < TransposeNames.size()) return TransposeNames[idx]; + return xla::InternalError("Unknown transpose type!"); +} + +xla::StatusOr String2Transpose(absl::string_view s) { + for(size_t i = 0; i < TransposeNames.size(); i++) { + if (s == TransposeNames[i]) return static_cast< blas::Transpose >(i); + } + return xla::InternalError("Unknown tranpose type!"); +} + +const std::vector TypeNames = { + "f32_r", //kFloat = 0, + "f64_r", //kDouble = 1, + "f16_r", //kHalf = 2, + "i8_r", //kInt8 = 3, + "i32_r", //kInt32 = 4, + "f32_c", //kComplexFloat = 5, + "f64_c", //kComplexDouble = 6, + "bf16_r", //kBF16 = 7, +}; + +xla::StatusOr Type2String(blas::DataType type) { + size_t idx = static_cast< size_t >(type); + if (idx < TypeNames.size()) return TypeNames[idx]; + return xla::InternalError("Unknown data type!"); +} + +} // namespace + +std::string ToCSVString(const GemmConfig& cfg, bool full_string) { + + ///constexpr char kCsvComment = '#'; + constexpr char kCsvSep = ','; + + const auto& L = cfg.lhs_layout, &R = cfg.rhs_layout, &O = cfg.output_layout; + + std::ostringstream oss; + auto type_a = Type2String(L.dtype).value(), + type_b = Type2String(R.dtype).value(), + type_c = Type2String(O.dtype).value(), + trans_a = Transpose2String(L.transpose).value(), + trans_b = Transpose2String(R.transpose).value(); + +// LHS: k x n +// RHS: m x k +// OUT: m x n + // VLOG(0) << "LHS: " << L.num_cols << "x" << L.num_rows; + // VLOG(0) << "RHS: " << R.num_cols << "x" << R.num_rows; + // VLOG(0) << "OUT: " << O.num_cols << "x" << O.num_rows; + int n = L.num_rows, k = L.num_cols, m = O.num_cols; + oss << m << kCsvSep << n << kCsvSep << k << kCsvSep + << O.batch_size << kCsvSep << trans_a << kCsvSep + << trans_b << kCsvSep << type_a << kCsvSep + << type_b << kCsvSep << type_c << kCsvSep << L.leading_dim_stride + << kCsvSep << R.leading_dim_stride << kCsvSep + << O.leading_dim_stride << kCsvSep << L.batch_stride << kCsvSep + << R.batch_stride << kCsvSep << O.batch_stride; + + if (full_string) { + // NOTE: epilogue is required for MatmulPlan caching ! + oss << kCsvSep << cfg.alpha.real() << kCsvSep << cfg.alpha.imag() << kCsvSep << cfg.beta << kCsvSep << (int64_t)cfg.epilogue; + } + + return oss.str(); +} + +} // namespace gpu + +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h new file mode 100644 index 00000000000000..90625d46c46a5e --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h @@ -0,0 +1,278 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_H_ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "absl/types/any.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/stream_executor/blas.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/host_or_device_scalar.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace stream_executor { + +namespace gpu { + +bool GpuBlasLtEnabled(); + +xla::StatusOr AsBlasDataType(xla::PrimitiveType dtype); + +xla::StatusOr GetBlasComputationType( + blas::DataType lhs_dtype, blas::DataType output_dtype, + int64_t compute_precision); + +// Returns the type for the alpha and beta scalars. +blas::DataType GetScaleType(blas::DataType c_type, + blas::ComputationType computation_type); + +struct MatrixLayout { // plain MatrixLayout which is extended with create + // functions in matmul_utils.h + enum class Order { + kRowMajor, // Elements in the same row are contiguous in memory. + kColumnMajor, // Elements in the same column are contiguous in memory. + }; + + MatrixLayout() = default; + + MatrixLayout(blas::DataType dtype_, int64_t num_rows_, int64_t num_cols_, + Order order_, int64_t batch_size_ = 1, + absl::optional leading_dim_stride_ = {}, + absl::optional batch_stride_ = {}, + absl::optional transpose_ = {}); + + void Transpose(); + + blas::DataType dtype; + // `num_rows` / `num_cols` are for the "logical" matrix shape: + // i.e. the contracting dim has size `num_cols` for LHS operands and + // `num_rows` for RHS operands. + int64_t num_rows; + int64_t num_cols; + Order order; + int64_t batch_size; + int64_t leading_dim_stride; + // `batch_stride` is set to `0` when `batch_size == 1`. + int64_t batch_stride; + blas::Transpose transpose; +}; + +// compact version of the matrix layout to be used to pass matrices +// to underlying blas API +struct MatrixDescriptor { + DeviceMemoryBase data; + int64_t leading_dim_stride = 0; + int64_t batch_stride = 0; + blas::DataType type{}; + blas::Transpose transpose{}; + + template + DeviceMemory cast() const { + return DeviceMemory(data); + } +}; + +struct OutputMatrixDescriptor : public MatrixDescriptor { + OutputMatrixDescriptor(MatrixDescriptor&& parent) noexcept + : MatrixDescriptor(std::move(parent)) {} + int64_t batch_size = 0; + int64_t m = 0, n = 0, k = 0; + blas::ComputationType compute_type{}; +}; + +// BLAS GeMM's output is column-major. If we require row-major, use identity: +// C^T = (A @ B)^T = B^T @ A^T. +bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, + MatrixLayout& output, MatrixLayout* c = nullptr); + +struct GemmConfig; + +struct GroupedGemmConfig { + int64_t m, n, k, batch_count; + blas::Transpose trans_a, trans_b; + const void *alpha, *beta; + blas::DataType type_a, type_b, type_c, type_d; + int64_t lda, ldb, ldc, ldd; + blas::ComputationType compute_type; + const void **a, **b, **c; + void **d; +}; + +struct BlasLt { + + static constexpr int64_t kMaxAlgorithms = 128; + + enum class Epilogue { + kDefault = 1, // No special postprocessing + kReLU = 2, // Apply point-wise ReLU function + kBias = 4, // Add broadcasted bias vector + kBiasThenReLU = kBias | kReLU, // Apply bias and then ReLU transform + kGELU = 32, // Apply GELU point-wise transform to the results + kGELUWithAux = 32 | 1024, // Apply GELU with auxiliary output. + kBiasThenGELU = kBias | kGELU, // Apply bias and then approximate GELU. + kBiasThenGELUWithAux = kBiasThenGELU | 1024, + }; + + // Describes the location of pointers for the scaling factors alpha and beta. + enum class PointerMode { + kHost, + kDevice, + }; + + struct MatmulAlgorithm { + absl::any opaque_algo; + size_t workspace_size; + blas::AlgorithmType id; + }; + + struct MatmulPlan { + // DoMatmul provides two sets of API for maintaning compatibility for XLA, + // and TF. One set API uses scratch_allocator to allocate workspace, and one + // set API allow uses to provide pre-allocated buffer as workspace. + + // Returns a list of supported algorithms for DoMatmul. The algorithms are + // returned in the order of increasing estimated compute time according to + // an internal heuristic. + virtual xla::StatusOr> GetAlgorithms( + size_t max_algorithm_count = kMaxAlgorithms, + size_t max_workspace_size = 1ll << 32) const = 0; + + // Algorithm needs to be set before calling ExecuteOnStream function + virtual xla::Status SetAlgorithm(const MatmulAlgorithm& algorithm) = 0; + + // The most general form: to be implemented by derived clases. + virtual xla::Status ExecuteOnStream( + Stream* stream, DeviceMemoryBase a_buffer, DeviceMemoryBase b_buffer, + DeviceMemoryBase c_buffer, DeviceMemoryBase d_buffer, + DeviceMemoryBase bias_buffer, // may be null + DeviceMemoryBase aux_buffer, // may be null + DeviceMemoryBase a_scale_buffer, DeviceMemoryBase b_scale_buffer, + DeviceMemoryBase c_scale_buffer, DeviceMemoryBase d_scale_buffer, + DeviceMemoryBase d_amax_buffer, + absl::optional workspace, + absl::optional scratch_allocator = absl::nullopt, + blas::ProfileResult* profile_result = nullptr) const = 0; + + virtual ~MatmulPlan() {} + + protected: + // might be used internally by ExecuteOnStream in derived classes + template + xla::Status DoMatmul(Stream* stream, xla::complex128 alpha, + DeviceMemoryBase a, DeviceMemoryBase b, double beta, + DeviceMemoryBase c, DeviceMemoryBase d, + DeviceMemoryBase bias, DeviceMemoryBase aux, + DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, + DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, + DeviceMemoryBase d_amax, + absl::optional workspace, + absl::optional scratch_allocator, + blas::ProfileResult* profile_result = nullptr) const { + Scale salpha; + if constexpr(std::is_same::value || + std::is_same::value) { + salpha = static_cast(alpha); + } else { + salpha = static_cast(alpha.real()); + } + Scale sbeta = static_cast(beta); + return DoMatmul(stream, &salpha, a, b, &sbeta, c, d, + bias, aux, a_scale, b_scale, c_scale, d_scale, + d_amax, workspace, scratch_allocator, profile_result); + } + + // The most general version to be implemented by derived classes + virtual xla::Status DoMatmul( + Stream* stream, const void* alpha, DeviceMemoryBase a, + DeviceMemoryBase b, const void* beta, DeviceMemoryBase c, + DeviceMemoryBase d, DeviceMemoryBase bias, + DeviceMemoryBase aux, DeviceMemoryBase a_scale, + DeviceMemoryBase b_scale, DeviceMemoryBase c_scale, + DeviceMemoryBase d_scale, DeviceMemoryBase d_amax, + absl::optional workspace, + absl::optional scratch_allocator, + blas::ProfileResult* profile_result = nullptr) const = 0; + }; // class MatmulPlan + + struct GroupedMatmulPlan { + + virtual xla::StatusOr> GetAlgorithms( + size_t max_algorithm_count = kMaxAlgorithms, + size_t max_workspace_size = 1ll << 32) = 0; + + virtual xla::Status SetAlgorithm(const MatmulAlgorithm& algorithm, + ScratchAllocator * scratch_allocator) = 0; + + virtual xla::Status ExecuteOnStream(Stream *stream, + const gpu::GroupedGemmConfig& cfg, + blas::ProfileResult* profile_result = nullptr) = 0; + + virtual ~GroupedMatmulPlan() {} + }; + + using MatmulPlanPtr = std::unique_ptr; + using GroupedMatmulPlanPtr = std::unique_ptr; + + virtual xla::Status Init() = 0; + + virtual xla::StatusOr GetMatmulPlan( + const GemmConfig& cfg) const = 0; + + virtual xla::StatusOr GetGroupedMatmulPlan( + Stream *stream, + const GroupedGemmConfig& config) const = 0; + + static BlasLt* Get(const Stream* stream); + + // convenience function to create MatmulPlan directly using stream + static xla::StatusOr GetMatmulPlan(const Stream* stream, + const GemmConfig& cfg); + + // convenience function to create GroupedMatmulPlan directly using stream + static xla::StatusOr CreateGroupedMatmulPlan( + Stream* stream, const GroupedGemmConfig& cfg); + + virtual ~BlasLt() {} +}; // class BlasLt + +struct GemmConfig { // plain GemmConfig which is extended with create functions + // in matmul_utils.h + MatrixLayout lhs_layout; + MatrixLayout rhs_layout; + MatrixLayout c_layout; + MatrixLayout output_layout; + xla::complex128 alpha; + double beta; + blas::AlgorithmType algorithm; + int64_t compute_precision; + BlasLt::Epilogue epilogue; +}; + +std::string ToCSVString(const GemmConfig& cfg, bool full_string); + + +} // namespace gpu + +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_H_ diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc new file mode 100644 index 00000000000000..8693b0e42300f8 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc @@ -0,0 +1,341 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/core/util/env_var.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" + +namespace stream_executor { +namespace gpu { + +bool BlasLtGemmRunner::autotune_enabled_ = true; + +bool operator ==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs) { + return AsTuple(rhs) == AsTuple(lhs); +} + +bool operator ==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs) { + return AsTuple(rhs) == AsTuple(lhs); +} + +std::ostream& operator <<(std::ostream& os, const StridedGemmConfig& cfg) { + return os << "trans_a/b: " << (int)cfg.trans_a << "/" << (int)cfg.trans_b << + " m: " << cfg.m << " n: " << cfg.n << " k: " << cfg.k << + " batch_count: " << cfg.batch_count << + " lda: " << cfg.lda << " ldb: " << cfg.ldb << " ldc: " << cfg.ldc << + " stride_a: " << cfg.stride_a << " stride_b: " << cfg.stride_b << + " stride_c: " << cfg.stride_c << + " type_a: " << (int)cfg.type_a << " type_b: " << (int)cfg.type_b << + " type_c: " << (int)cfg.type_c << + " alpha: " << cfg.alpha << " beta: " << cfg.beta; +} + +BlasLtGemmRunner::BlasLtGemmRunner(StreamExecutor *parent) : + mutex_(std::make_unique< absl::Mutex >()), + autotune_config_(std::make_unique< xla::gpu::AutotuneConfig >( + xla::gpu::DeviceConfig{parent, nullptr}, + xla::GetDebugOptionsFromFlags())) + { } + +BlasLtGemmRunner::~BlasLtGemmRunner() { } + + +/*static*/ BlasLtGemmRunner& BlasLtGemmRunner::i(const Stream *stream) { + static absl::Mutex m(absl::kConstInit); + // Each GPU gets a different cache instance + static std::vector> meta(8); + absl::MutexLock lock(&m); + size_t dev_id = stream->parent()->device_ordinal(); + if (dev_id >= meta.size()) meta.resize(dev_id + 1); + auto& res = meta[dev_id]; + if (!res) { + autotune_enabled_ = xla::GetDebugOptionsFromFlags().xla_gpu_autotune_level() > 0; + res.reset(new BlasLtGemmRunner(stream->parent())); + xla::gpu::AutotunerUtil::LoadAutotuneResultsFromFileOnce(*res->autotune_config_); + } + return *res; +} + +template < class TuneFunc > +xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > BlasLtGemmRunner::Autotune( + const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms, + TuneFunc&& benchmark_func) { + gpu::BlasLt::MatmulAlgorithm best_algo; + float best_ms = std::numeric_limits< float >::max(), total_ms = 0; + uint32_t n_warmups = 1, n_iters = 5, n_total = n_warmups + n_iters, i = 0; + + for (uint32_t j = 0; j < algorithms.size(); j++) { + const auto& algo = algorithms[j]; + if (!benchmark_func(algo, nullptr).ok()) continue; + + blas::ProfileResult profile; + for (i = 0, total_ms = 0; i < n_total; i++) { + auto res = benchmark_func(algo, &profile); + if (!res.ok() || !profile.is_valid()) { + VLOG(1) << j << ": gemm algorithm is not valid: " /* << res.error_message() */; + break; + } + if (i >= n_warmups) total_ms += profile.elapsed_time_in_ms(); + } + if (i < n_total) continue; // invalid algorithm + total_ms /= n_iters; + VLOG(2) << j << ": gemm algorithm " << profile.algorithm() << " took " + << total_ms << "ms, workspace: " << algo.workspace_size; + if (total_ms < best_ms) { + best_ms = total_ms, best_algo = algo; + } + } // for algorithms + if (!best_algo.opaque_algo.has_value()) { + return xla::InternalError("No valid gemm algorithms found!"); + } + return best_algo; +} + +xla::StatusOr< std::array< uint64_t, 3 >> BlasLtGemmRunner::ContiguousStrides( + const ArraySlice& a, + const ArraySlice& b, + const ArraySlice& c, int64 batch_count) { + + uint64_t bsa = 0, bsb = 0, bsc = 0; + using CT = const uint8_t; + for(int64 i = 0; i < batch_count-1; i++) { + uint64_t da = (CT *)a[i + 1]->opaque() - (CT *)a[i]->opaque(), + db = (CT *)b[i + 1]->opaque() - (CT *)b[i]->opaque(), + dc = (CT *)c[i + 1]->opaque() - (CT *)c[i]->opaque(); + if(i == 0) { + bsa = da, bsb = db, bsc = dc; + } else if(!(bsa == da && bsb == db && bsc == dc)) { // strides mismatch + return xla::InternalError("Strides are not consistent!"); + } + } + return std::array< uint64_t, 3 >{ bsa, bsb, bsc }; +} + +xla::Status BlasLtGemmRunner::RunBatchedImpl(Stream& stream, + blas::Transpose trans_a, blas::Transpose trans_b, int64 m, int64 n, int64 k, + const void *alpha, blas::DataType type_a, const void** a, int64 lda, + blas::DataType type_b, const void** b, int64 ldb, const void *beta, + blas::DataType type_c, void** c, int64 ldc, int64 batch_count, + ScratchAllocator* allocator) +{ + + TF_ASSIGN_OR_RETURN(auto compute_type, + gpu::GetBlasComputationType(type_a, type_c, 0)); + + GroupedGemmConfig cfg{ + .m = (int64)m, + .n = (int64)n, + .k = (int64)k, + .batch_count = (int64)batch_count, + .trans_a = trans_a, + .trans_b = trans_b, + .alpha = alpha, + .beta = beta, + .type_a = type_a, + .type_b = type_b, + .type_c = type_c, + .type_d = type_c, + .lda = (int64)lda, + .ldb = (int64)ldb, + .ldc = (int64)ldc, + .ldd = (int64)ldc, + .compute_type = compute_type, + .a = a, + .b = b, + .c = const_cast< const void **>(c), + .d = c, + }; + + absl::MutexLock lock(mutex_.get()); + + auto res = grouped_gemm_map_.find(cfg); + if (res == grouped_gemm_map_.end()) { + // NOTE: we assume that pointers a,b,c come from the device mem + // hence we need to block stream here + TF_ASSIGN_OR_RETURN(auto plan_res, + gpu::BlasLt::CreateGroupedMatmulPlan(&stream, cfg)); + res = grouped_gemm_map_.emplace(cfg, std::move(plan_res)).first; + + size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1; + // discard solutions with non-zero workspace if allocator is not given + TF_ASSIGN_OR_RETURN(auto algorithms, res->second->GetAlgorithms( + num_solutions, allocator == nullptr ? 0 : 1ull << 32)); + + VLOG(1) << stream.parent() << ": new GGemm config: " << + grouped_gemm_map_.size() << " #valid algorithms: " << algorithms.size(); + + BlasLt::MatmulAlgorithm best_algo; + if (!autotune_enabled_) { + if (algorithms.empty()) return xla::InternalError("No GG algorithms found!"); + best_algo = algorithms[0]; // otherwise use default algorithm + } else { + TF_ASSIGN_OR_RETURN(auto best_algo, Autotune(algorithms, + [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){ + if (profile == nullptr) { + return res->second->SetAlgorithm(algo, allocator); + } + return res->second->ExecuteOnStream(&stream, cfg, profile); + })); + } + TF_RETURN_IF_ERROR(res->second->SetAlgorithm(best_algo, allocator)); + } + return res->second->ExecuteOnStream(&stream, cfg); +} + +xla::Status BlasLtGemmRunner::RunStridedBatchedImpl(Stream& stream, + blas::Transpose trans_a, blas::Transpose trans_b, int64 m, int64 n, int64 k, + xla::complex128 alpha, + blas::DataType type_a, const DeviceMemoryBase& a, int64 lda, int64 stride_a, + blas::DataType type_b, const DeviceMemoryBase& b, int64 ldb, int64 stride_b, + double beta, + blas::DataType type_c, DeviceMemoryBase *c, int64 ldc, int64 stride_c, + int64 batch_count, ScratchAllocator* allocator) +{ + StridedGemmConfig scfg{ + .m = m, + .n = n, + .k = k, + .batch_count = (int64)batch_count, + .trans_a = trans_a, + .trans_b = trans_b, + .alpha = alpha, + .beta = beta, + .type_a = type_a, + .type_b = type_b, + .type_c = type_c, + .lda = lda, + .ldb = ldb, + .ldc = ldc, + .stride_a = stride_a, + .stride_b = stride_b, + .stride_c = stride_c, + }; + + absl::MutexLock lock(mutex_.get()); + + auto res = strided_gemm_map_.find(scfg); + while (res == strided_gemm_map_.end()) { + int64 row_a = m, col_a = k, row_b = k, col_b = n; + if (trans_a == blas::Transpose::kTranspose) std::swap(row_a, col_a); + if (trans_b == blas::Transpose::kTranspose) std::swap(row_b, col_b); + + auto order = MatrixLayout::Order::kColumnMajor; + GemmConfig cfg = { + .lhs_layout = MatrixLayout(type_a, row_a, col_a, order, batch_count, + lda, stride_a, trans_a), + + .rhs_layout = MatrixLayout(type_b, row_b, col_b, order, batch_count, + ldb, stride_b, trans_b), + + .c_layout = MatrixLayout(type_c, m, n, order, batch_count, + ldc, stride_c), + .output_layout = MatrixLayout(type_c, m, n, order, batch_count, + ldc, stride_c), + .alpha = alpha, + .beta = beta, + .compute_precision = -1, + .epilogue = gpu::BlasLt::Epilogue::kDefault, + }; + + TF_ASSIGN_OR_RETURN(auto plan_res, + gpu::BlasLt::GetMatmulPlan(&stream, cfg)); + res = strided_gemm_map_.emplace(scfg, std::move(plan_res)).first; + + size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1; + // discard solutions with non-zero workspace if allocator is not given + TF_ASSIGN_OR_RETURN(auto algorithms, res->second->GetAlgorithms( + num_solutions, allocator == nullptr ? 0 : 1ull << 32)); + + VLOG(1) << &stream << " dev " << stream.parent() << '(' << + stream.parent()->device_ordinal() << "): new StridedBatched config: " + << strided_gemm_map_.size() << " #algorithms: " << algorithms.size(); + + if (!autotune_enabled_) { + if (algorithms.empty()) return xla::InternalError("No algorithms found!"); + res->second->SetAlgorithm(algorithms[0]); + break; + } + + BlasLt::MatmulAlgorithm best_algo{ .id = blas::kNoAlgorithm }; + xla::gpu::AutotuneCacheKey key(ToCSVString(cfg, /*full_string*/false)); + auto opt_res = xla::gpu::AutotunerUtil::TryToFindInInMemoryCache(key); + if (opt_res.has_value()) { + auto id = *opt_res; + for (const auto& algo : algorithms) { + if (algo.id == id) best_algo = algo; + } + if (best_algo.id == blas::kNoAlgorithm) { + LOG(WARNING) << "Best algorithm not valid: need to autotune.."; + } + } + + if (best_algo.id == blas::kNoAlgorithm) { + TF_ASSIGN_OR_RETURN(best_algo, Autotune(algorithms, + [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){ + if (profile == nullptr) { + return res->second->SetAlgorithm(algo); + } + return res->second->ExecuteOnStream( + &stream, a, b, *c, *c, + DeviceMemoryBase{}, // bias + DeviceMemoryBase{}, // aux + DeviceMemoryBase{}, // a_scale + DeviceMemoryBase{}, // b_scale + DeviceMemoryBase{}, // c_scale + DeviceMemoryBase{}, // d_scale + DeviceMemoryBase{}, // d_amax + absl::nullopt, // workspace + allocator, // allocator + profile); + })); + xla::gpu::AutotunerUtil::CacheValue ares = best_algo.id; + // reread algorithm ID from cache again (in case some other thread has + // already added this config to the cache to be sure we use the same ID) + auto new_id = xla::gpu::AutotunerUtil::AddResultToInMemoryCache(key, ares, + *autotune_config_); + + if (new_id != best_algo.id) { + for (const auto& algo : algorithms) { + if (algo.id == new_id) best_algo = algo; + } + } + } // best_algo.id == blas::kNoAlgorithm + + res->second->SetAlgorithm(best_algo); + break; + } // while + return res->second->ExecuteOnStream( + &stream, a, b, *c, *c, + DeviceMemoryBase{}, // bias + DeviceMemoryBase{}, // aux + DeviceMemoryBase{}, // a_scale + DeviceMemoryBase{}, // b_scale + DeviceMemoryBase{}, // c_scale + DeviceMemoryBase{}, // d_scale + DeviceMemoryBase{}, // d_amax + absl::nullopt, // workspace + allocator); // allocator +} + +} // namespace gpu + +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h new file mode 100644 index 00000000000000..97c18a5e64ed42 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h @@ -0,0 +1,260 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" +#include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/compiler/xla/util.h" + +using tensorflow::gtl::ArraySlice; +typedef ::std::int64_t int64; + + +namespace xla { +namespace gpu { +class AutotuneConfig; +} +} + +namespace stream_executor { + +namespace gpu { + +struct StridedGemmConfig { + int64 m, n, k, batch_count; + blas::Transpose trans_a, trans_b; + xla::complex128 alpha; + double beta; + blas::DataType type_a, type_b, type_c; + int64 lda, ldb, ldc; + int64 stride_a, stride_b, stride_c; +}; + +namespace { + +auto AsTuple(const GroupedGemmConfig& p) { + // NOTE: alpha, beta and data pointers are not included in cache !! + return std::make_tuple(p.m, p.n, p.k, p.batch_count, + p.trans_a, p.trans_b, + p.type_a, p.type_b, p.type_c, p.type_d, + p.lda, p.ldb, p.ldc, p.ldd, + p.compute_type); +} + +auto AsTuple(const StridedGemmConfig& p) { + return std::make_tuple(p.m, p.n, p.k, p.batch_count, + p.trans_a, p.trans_b, p.alpha.real(), p.alpha.imag(), p.beta, + p.type_a, p.type_b, p.type_c, + p.lda, p.ldb, p.ldc, + p.stride_a, p.stride_b, p.stride_c); +} + +} // namespace + +bool operator ==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs); +bool operator ==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs); + +template +H AbslHashValue(H h, const GroupedGemmConfig& params) { + return H::combine(std::move(h), AsTuple(params)); +} + +template +H AbslHashValue(H h, const StridedGemmConfig& params) { + return H::combine(std::move(h), AsTuple(params)); +} + +struct BlasLtGemmRunner { + + static BlasLtGemmRunner& i(const Stream *stream); + + template < class Scalar > + xla::complex128 Convert(Scalar x) { + if constexpr(std::is_same::value || + std::is_same::value) { + return static_cast< xla::complex128 >(x); + } else { + return static_cast< double >(x); + } + } + + template < class Scalar, class TypeA, class TypeB, class TypeC > + xla::Status Run(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + Scalar alpha, const DeviceMemory& a, int64 lda, + const DeviceMemory& b, int64 ldb, + Scalar beta, DeviceMemory *c, int64 ldc, + ScratchAllocator* allocator) { + + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k, + Convert(alpha), + type_a, a, lda, 0, type_b, b, ldb, 0, + Convert(beta).real(), // only real betas are supported!! + type_c, c, ldc, 0, 1, allocator); + } + + template < class Scalar, class TypeA, class TypeB, class TypeC > + xla::Status Run(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + Scalar alpha, const TypeA* a, int64 lda, + const TypeB *b, int64 ldb, + Scalar beta, TypeC *c, int64 ldc, + ScratchAllocator* allocator) { + + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + + DeviceMemoryBase mem_c{c}; + return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k, + Convert(alpha), + type_a, DeviceMemoryBase{const_cast< TypeA *>(a)}, lda, 0, + type_b, DeviceMemoryBase{const_cast< TypeB *>(b)}, ldb, 0, + Convert(beta).real(), // only real betas are supported!! + type_c, &mem_c, ldc, 0, 1, allocator); + } + + + template < class Scalar, class TypeA, class TypeB, class TypeC> + xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + Scalar alpha, const TypeA* a, int64 lda, int64 stride_a, + const TypeB* b, int64 ldb, int64 stride_b, + Scalar beta, TypeC* c, int64 ldc, int64 stride_c, + int64 batch_count, ScratchAllocator* allocator) { + + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + DeviceMemoryBase mem_c{c}; + return RunStridedBatchedImpl( + stream, trans_a, trans_b, m, n, k, Convert(alpha), type_a, + DeviceMemoryBase{const_cast(a)}, lda, stride_a, type_b, + DeviceMemoryBase{const_cast(a)}, ldb, stride_b, + Convert(beta).real(), // only real betas are supported!! + type_c, &mem_c, ldc, stride_c, batch_count, allocator); + } + + template < class Scalar, class TypeA, class TypeB, class TypeC> + xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + Scalar alpha, const DeviceMemory& a, int64 lda, int64 stride_a, + const DeviceMemory& b, int64 ldb, int64 stride_b, + Scalar beta, DeviceMemory *c, int64 ldc, int64 stride_c, + int64 batch_count, ScratchAllocator* allocator) { + + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k, + Convert(alpha), + type_a, a, lda, stride_a, type_b, b, ldb, stride_b, + Convert(beta).real(), // only real betas are supported!! + type_c, c, ldc, stride_c, batch_count, allocator); + } + + template < class Scalar, class T > + xla::Status RunBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, Scalar alpha, + const ArraySlice *> &a, int64 lda, + const ArraySlice *> &b, int64 ldb, Scalar beta, + const ArraySlice *> &c, int64 ldc, + int64 batch_count, ScratchAllocator* allocator) { + + // NOTE: Scalar types shall be verified for correctness vs T!! + auto type = dnn::ToDataType::value; + auto cvt = [](auto x){ + using TT = ArraySlice; + auto ptr = reinterpret_cast(&x); + return *reinterpret_cast(ptr); + }; + + auto res = ContiguousStrides(cvt(a), cvt(b), cvt(c), batch_count); + if (res.ok()) { + auto strides = std::move(res.value()); + return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k, + Convert(alpha), + type, *a[0], lda, strides[0] / sizeof(T), + type, *b[0], ldb, strides[1] / sizeof(T), + Convert(beta).real(), // only real betas are supported!! + type, c[0], ldc, strides[2] / sizeof(T), batch_count, allocator); + } + return xla::InternalError("RunBatched: port::ArraySlice NYI!"); + } + + + template < class Scalar, class T > + xla::Status RunBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, uint64 m, uint64 n, uint64 k, + Scalar alpha, const T** a, int lda, + const T** b, int ldb, Scalar beta, + T** c, int64 ldc, int64 batch_count, ScratchAllocator* allocator){ + + auto type = dnn::ToDataType::value; + return RunBatchedImpl(stream, trans_a, trans_b, m, n, k, + &alpha, type, reinterpret_cast< const void **>(a), lda, + type, reinterpret_cast< const void **>(b), ldb, &beta, + type, reinterpret_cast< void **>(c), ldc, batch_count, allocator); + } + + ~BlasLtGemmRunner(); + BlasLtGemmRunner& operator=(BlasLtGemmRunner&& rhs) noexcept = default; + BlasLtGemmRunner(BlasLtGemmRunner&& rhs) noexcept = default; + +private: + explicit BlasLtGemmRunner(StreamExecutor *parent); + + template < class TuneFunc > + xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > Autotune( + const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms, + TuneFunc&& benchmark_func); + + + xla::Status RunBatchedImpl(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + const void *alpha, blas::DataType type_a, const void** a, int64 lda, + blas::DataType type_b, const void** b, int64 ldb, const void *beta, + blas::DataType type_c, void** c, int64 ldc, int64 batch_count, + ScratchAllocator* allocator); + + xla::Status RunStridedBatchedImpl(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, xla::complex128 alpha, + blas::DataType type_a, const DeviceMemoryBase& a, int64 lda, int64 stride_a, + blas::DataType type_b, const DeviceMemoryBase& b, int64 ldb, int64 stride_b, + double beta, + blas::DataType type_c, DeviceMemoryBase *c, int64 ldc, int64 stride_c, + int64 batch_count, ScratchAllocator* allocator); + + xla::StatusOr< std::array< uint64_t, 3 >> ContiguousStrides( + const ArraySlice& a, + const ArraySlice& b, + const ArraySlice& c, int64 batch_count); + + static bool autotune_enabled_; + std::unique_ptr< absl::Mutex > mutex_; + std::unique_ptr< xla::gpu::AutotuneConfig > autotune_config_; + absl::flat_hash_map grouped_gemm_map_; + absl::flat_hash_map strided_gemm_map_; +}; + +} // namespace gpu + +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h index 397f29b44f7643..0b06b0dd772f1b 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h @@ -264,7 +264,7 @@ class GpuDriver { // way. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15 static tsl::Status LaunchKernel( - GpuContext* context, absl::string_view kernel_name, + GpuContext* context, GpuFunctionHandle function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_kernel.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_kernel.h index 42ca900cf55d5f..a6e4be05ca5cea 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/gpu_kernel.h +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_kernel.h @@ -37,7 +37,7 @@ class GpuKernel : public internal::KernelInterface { public: GpuKernel() : gpu_function_(nullptr), - arity_(0), + arity_(0), inprocess_(false), preferred_cache_config_(KernelCacheConfig::kNoPreference) {} // Note that the function is unloaded when the module is unloaded, and the @@ -55,6 +55,9 @@ class GpuKernel : public internal::KernelInterface { return const_cast(gpu_function_); } + void SetInProcessSymbol(bool inprocess) { inprocess_ = inprocess; } + bool IsInProcessSymbol() const { return inprocess_; } + // Returns the slot that the GpuFunctionHandle is stored within for this // object, for the CUDA API which wants to load into a GpuFunctionHandle*. GpuFunctionHandle* gpu_function_ptr() { return &gpu_function_; } @@ -82,6 +85,7 @@ class GpuKernel : public internal::KernelInterface { private: GpuFunctionHandle gpu_function_; // Wrapped CUDA kernel handle. unsigned arity_; // Number of formal parameters the kernel takes. + bool inprocess_; // Preferred (but not required) cache configuration for this kernel. KernelCacheConfig preferred_cache_config_; diff --git a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc index 1ab21ed78506ab..1d72734ef51dc7 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc +++ b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc @@ -15,25 +15,28 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" +#include #include #include +#include #include +#include +#include -#include "absl/base/call_once.h" #include "absl/container/fixed_array.h" -#include "absl/status/status.h" #include "absl/strings/str_format.h" -#include "absl/types/optional.h" + +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" -#include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" -#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel.h" #include "tensorflow/compiler/xla/stream_executor/kernel.h" -#include "tensorflow/compiler/xla/stream_executor/kernel_spec.h" +#include "tensorflow/compiler/xla/stream_executor/launch_dim.h" #include "tensorflow/compiler/xla/stream_executor/stream.h" -#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" -#include "tensorflow/tsl/framework/allocator.h" -#include "tensorflow/tsl/platform/errors.h" -#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/lib/math/math_util.h" + +#include "tensorflow/core/util/env_var.h" namespace stream_executor { @@ -41,7 +44,7 @@ namespace stream_executor { // then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16 template static T RoundUpToNearest(T value, T divisor) { - return tsl::MathUtil::CeilOfRatio(value, divisor) * divisor; + return tensorflow::MathUtil::CeilOfRatio(value, divisor) * divisor; } // The size of the redzone at the end of the user buffer is rounded up to a @@ -50,20 +53,21 @@ constexpr int64_t kRhsRedzoneAlign = 4; using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; -RedzoneAllocator::RedzoneAllocator(Stream* stream, - DeviceMemoryAllocator* memory_allocator, - GpuAsmOpts ptx_compilation_opts, - int64_t memory_limit, int64_t redzone_size, - uint8_t redzone_pattern) +RedzoneAllocator::RedzoneAllocator( + Stream* stream, + DeviceMemoryAllocator* memory_allocator, + const GpuAsmOpts& gpu_compilation_opts, + int64_t memory_limit, int64_t redzone_size, + uint8_t redzone_pattern) : device_ordinal_(stream->parent()->device_ordinal()), stream_(stream), memory_limit_(memory_limit), redzone_size_(RoundUpToNearest( redzone_size, - static_cast(tsl::Allocator::kAllocatorAlignment))), + static_cast(tensorflow::Allocator::kAllocatorAlignment))), redzone_pattern_(redzone_pattern), memory_allocator_(memory_allocator), - gpu_compilation_opts_(ptx_compilation_opts) {} + gpu_compilation_opts_(gpu_compilation_opts) {} tsl::StatusOr> RedzoneAllocator::AllocateBytes( int64_t byte_size) { @@ -106,7 +110,7 @@ tsl::StatusOr> RedzoneAllocator::AllocateBytes( redzone_size_); uint8_t pattern_arr[] = {redzone_pattern_, redzone_pattern_, redzone_pattern_, - redzone_pattern_}; + redzone_pattern_}; uint32_t pattern32; std::memcpy(&pattern32, pattern_arr, sizeof(pattern32)); stream_->ThenMemset32(&lhs_redzone, pattern32, redzone_size_); @@ -119,66 +123,6 @@ tsl::StatusOr> RedzoneAllocator::AllocateBytes( return data_chunk; } -// PTX blob for the function which checks that every byte in -// input_buffer (length is buffer_length) is equal to redzone_pattern. -// -// On mismatch, increment the counter pointed to by out_mismatch_cnt_ptr. -// -// Generated from: -// __global__ void redzone_checker(unsigned char* input_buffer, -// unsigned char redzone_pattern, -// unsigned long long buffer_length, -// int* out_mismatched_ptr) { -// unsigned long long idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); -// } -// -// Code must compile for the oldest GPU XLA may be compiled for. -static const char* redzone_checker_ptx = R"( -.version 4.2 -.target sm_30 -.address_size 64 - -.visible .entry redzone_checker( - .param .u64 input_buffer, - .param .u8 redzone_pattern, - .param .u64 buffer_length, - .param .u64 out_mismatch_cnt_ptr -) -{ - .reg .pred %p<3>; - .reg .b16 %rs<3>; - .reg .b32 %r<6>; - .reg .b64 %rd<8>; - - ld.param.u64 %rd6, [buffer_length]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.u64.u32 %rd3, %r4; - setp.ge.u64 %p1, %rd3, %rd6; - @%p1 bra LBB6_3; - ld.param.u8 %rs1, [redzone_pattern]; - ld.param.u64 %rd4, [input_buffer]; - cvta.to.global.u64 %rd2, %rd4; - add.s64 %rd7, %rd2, %rd3; - ld.global.u8 %rs2, [%rd7]; - setp.eq.s16 %p2, %rs2, %rs1; - @%p2 bra LBB6_3; - ld.param.u64 %rd5, [out_mismatch_cnt_ptr]; - cvta.to.global.u64 %rd1, %rd5; - atom.global.add.u32 %r5, [%rd1], 1; -LBB6_3: - ret; -} -)"; - -// The PTX in redzone_checker_ptx has to be launched with specified types -// in the specified order. -using ComparisonKernelT = TypedKernel, uint8_t, uint64_t, - DeviceMemory>; // Check that redzones weren't overwritten on a host. // @@ -188,8 +132,9 @@ static tsl::StatusOr CheckRedzoneHost( absl::string_view name, Stream* stream, uint8_t redzone_pattern) { uint64_t size = redzone.size(); auto redzone_data = std::make_unique(size); - TF_RETURN_IF_ERROR(stream->ThenMemcpy(redzone_data.get(), redzone, size) - .BlockHostUntilDone()); + stream->ThenMemcpy(redzone_data.get(), redzone, size); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + std::array pattern_arr; pattern_arr.fill(redzone_pattern); @@ -217,35 +162,42 @@ static tsl::StatusOr CheckRedzoneHost( // Run the redzone checker on the provided buffer redzone. // // Increment out_param if mismatch occurs. -static tsl::Status RunRedzoneChecker( - Stream* stream, const DeviceMemory& redzone, - uint8_t redzone_pattern, const DeviceMemory& out_param, - const ComparisonKernelT& comparison_kernel) { +static tsl::Status RunRedzoneChecker(Stream* stream, + const DeviceMemory& redzone, + uint8_t redzone_pattern, + const DeviceMemory& out_param, + const ComparisonKernel& comparison_kernel) { StreamExecutor* executor = stream->parent(); + if (redzone.size() == 0) { + return tsl::OkStatus(); + } + int64_t num_elements = redzone.size(); int64_t threads_per_block = std::min( executor->GetDeviceDescription().threads_per_block_limit(), num_elements); int64_t block_count = - tsl::MathUtil::CeilOfRatio(num_elements, threads_per_block); + tensorflow::MathUtil::CeilOfRatio(num_elements, threads_per_block); TF_RETURN_IF_ERROR(stream->ThenLaunch( - ThreadDim(threads_per_block), BlockDim(block_count), comparison_kernel, - redzone, redzone_pattern, redzone.size(), out_param)); - return ::tsl::OkStatus(); + ThreadDim(threads_per_block), BlockDim(block_count), + comparison_kernel, redzone, redzone_pattern, + redzone.size(), out_param)); + return tsl::OkStatus(); } // Since we reuse the same buffer for multiple checks, we re-initialize redzone // with a NaN pattern after a failed check. // // This function is blocking, since redzone failing is a rare event. -static tsl::Status ReinitializeRedzone(Stream* stream, DeviceMemoryBase redzone, - uint8_t redzone_pattern) { +static tsl::Status ReinitializeRedzone(Stream* stream, + DeviceMemoryBase redzone, + uint8_t redzone_pattern) { absl::FixedArray redzone_array(redzone.size()); redzone_array.fill(redzone_pattern); stream->ThenMemcpy(&redzone, redzone_array.data(), redzone.size()); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - return ::tsl::OkStatus(); + return tsl::OkStatus(); } // Check redzones around the user allocation. @@ -254,7 +206,7 @@ static tsl::Status ReinitializeRedzone(Stream* stream, DeviceMemoryBase redzone, static tsl::StatusOr CheckRedzonesForBuffer( Stream* stream, DeviceMemoryBase memory, const DeviceMemory& out_param, - const ComparisonKernelT& comparison_kernel, int64_t user_allocation_size, + const ComparisonKernel& comparison_kernel, int64_t user_allocation_size, uint64_t redzone_size, uint8_t redzone_pattern) { StreamExecutor* executor = stream->parent(); int64_t rhs_slop = @@ -273,10 +225,10 @@ static tsl::StatusOr CheckRedzonesForBuffer( executor->GetSubBuffer(&buffer_uint8, redzone_size + user_allocation_size, /*element_count=*/redzone_size + rhs_slop); - TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, lhs_redzone, redzone_pattern, - out_param, comparison_kernel)); - TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, rhs_redzone, redzone_pattern, - out_param, comparison_kernel)); + TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, lhs_redzone, redzone_pattern, out_param, + comparison_kernel)); + TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, rhs_redzone, redzone_pattern, out_param, + comparison_kernel)); int64_t result; CHECK_EQ(out_param.size(), sizeof(result)); stream->ThenMemcpy(&result, out_param, sizeof(result)); @@ -304,47 +256,36 @@ static tsl::StatusOr CheckRedzonesForBuffer( } tsl::StatusOr RedzoneAllocator::CheckRedzones() const { + // add for PPU + static bool found_ppu_device = [] { + bool found_ppu = false; + if (!tensorflow::ReadBoolFromEnvVar("FOUND_PPU_DEVICE", false, &found_ppu).ok()) { + return false; + } + if (found_ppu) { + return true; + } + return false; + }(); + if (found_ppu_device) { + return RedzoneCheckStatus::OK(); + } StreamExecutor* executor = stream_->parent(); - absl::Span compiled_ptx = {}; - tsl::StatusOr> compiled_ptx_or = - CompileGpuAsmOrGetCached(executor->device_ordinal(), redzone_checker_ptx, - gpu_compilation_opts_); - if (compiled_ptx_or.ok()) { - compiled_ptx = compiled_ptx_or.value(); - } else { - static absl::once_flag ptxas_not_found_logged; - absl::call_once(ptxas_not_found_logged, [&]() { - LOG(WARNING) << compiled_ptx_or.status().ToString() - << "\nRelying on driver to perform ptx compilation. " - << "\nModify $PATH to customize ptxas location." - << "\nThis message will be only logged once."; - }); - } + TF_ASSIGN_OR_RETURN( + const ComparisonKernel* kernel, + GetComparisonKernel(stream_->parent(), gpu_compilation_opts_)); ScopedDeviceMemory out_param = executor->AllocateOwnedScalar(); stream_->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); -#if GOOGLE_CUDA - TF_ASSIGN_OR_RETURN( - std::shared_ptr loaded_kernel, - (LoadKernelOrGetPtr, uint8_t, uint64_t, - DeviceMemory>( - executor, "redzone_checker", redzone_checker_ptx, compiled_ptx))); -#else - TF_ASSIGN_OR_RETURN( - std::unique_ptr loaded_kernel, - (executor->CreateTypedKernel, uint8, uint64_t, - DeviceMemory>( - "redzone_checker", redzone_checker_ptx, compiled_ptx))); -#endif // GOOGLE_CUDA for (const auto& buf_and_size : allocated_buffers_) { TF_ASSIGN_OR_RETURN( RedzoneCheckStatus redzone_status, CheckRedzonesForBuffer(stream_, *buf_and_size.first, out_param.cref(), - *loaded_kernel, buf_and_size.second, + *kernel, buf_and_size.second, redzone_size_, redzone_pattern_)); if (!redzone_status.ok()) { return redzone_status; diff --git a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h index 2e3e5ba48f65e3..a015a4638e8ba0 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h +++ b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h @@ -13,18 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ -#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ +#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ #include +#include +#include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" -#include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" -#include "tensorflow/tsl/lib/math/math_util.h" + namespace stream_executor { @@ -43,12 +46,12 @@ class RedzoneAllocator : public ScratchAllocator { public: static constexpr int64_t kDefaultRedzoneSize = 1LL << 23; // 8MiB per side, 16MiB total. - static constexpr uint8_t kDefaultRedzonePattern = -1; // NOLINT + static constexpr uint8 kDefaultRedzonePattern = -1; RedzoneAllocator(Stream* stream, DeviceMemoryAllocator* memory_allocator, - GpuAsmOpts gpu_compilation_opts_, - int64_t memory_limit = (1LL << 32), // 4GB + const GpuAsmOpts& gpu_compilation_opts_, + int64_t memory_limit = (1LL << 32), int64_t redzone_size = kDefaultRedzoneSize, - uint8_t redzone_pattern = kDefaultRedzonePattern); + uint8 redzone_pattern = kDefaultRedzonePattern); // Redzones don't count towards the memory limit. int64_t GetMemoryLimitInBytes() override { return memory_limit_; } @@ -65,8 +68,7 @@ class RedzoneAllocator : public ScratchAllocator { RedzoneCheckStatus() = default; RedzoneCheckStatus(absl::string_view buffer_name, void* user_buffer_address, - int64_t offset, uint64_t expected_value, - uint64_t actual_value) + int64_t offset, uint64 expected_value, uint64 actual_value) : buffer_name(buffer_name), user_buffer_address(user_buffer_address), offset(offset), @@ -82,14 +84,14 @@ class RedzoneAllocator : public ScratchAllocator { std::string buffer_name = {}; void* user_buffer_address = nullptr; int64_t offset = 0; - uint64_t expected_value = 0; - uint64_t actual_value = 0; + uint64 expected_value = 0; + uint64 actual_value = 0; }; // Determines whether redzones around all allocated buffers are unmodified. // // Reinitializes redzones to the expected value, so that the same buffer - // could be reused for multiple checks. + // can be reused for multiple checks. // // Returns: // @@ -98,7 +100,6 @@ class RedzoneAllocator : public ScratchAllocator { // redzone has been detected. // - A stream error, if loading or launching the kernel has failed. tsl::StatusOr CheckRedzones() const; - Stream* stream() const { return stream_; } private: @@ -114,7 +115,7 @@ class RedzoneAllocator : public ScratchAllocator { // returned to users will be misaligned. const int64_t redzone_size_; - const uint8_t redzone_pattern_; + const uint8 redzone_pattern_; DeviceMemoryAllocator* memory_allocator_; GpuAsmOpts gpu_compilation_opts_; @@ -132,4 +133,4 @@ class RedzoneAllocator : public ScratchAllocator { } // namespace stream_executor -#endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ +#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ diff --git a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel.h b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel.h new file mode 100644 index 00000000000000..081f0ec0180d23 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel.h @@ -0,0 +1,39 @@ +/* Copyright 2024 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_KERNEL_H_ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_KERNEL_H_ + +#include + +#include "tensorflow/tsl/platform/statusor.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h" +#include "tensorflow/compiler/xla/stream_executor/kernel.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" + +namespace stream_executor { +using ComparisonKernel = TypedKernel, uint8_t, uint64_t, + DeviceMemory>; + +// Returns a GPU kernel that checks a memory location for redzone patterns. +// Parameters are (buffer_address, redzone_pattern, buffer_length, +// mismatch_count_ptr). For each byte in buffer `[buffer_address : +// buffer_address +// + buffer_length]` that is not equal to `redzone_pattern`, +// `*mismatch_count_ptr` gets incremented by 1. +tsl::StatusOr GetComparisonKernel( + StreamExecutor* executor, GpuAsmOpts gpu_asm_opts); + +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_KERNEL_H_ diff --git a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc new file mode 100644 index 00000000000000..fccdfc965d40fd --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc @@ -0,0 +1,147 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/const_init.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_asm_compiler.h" +#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel.h" +#include "tensorflow/compiler/xla/stream_executor/kernel.h" +#include "tensorflow/compiler/xla/stream_executor/typed_kernel_factory.h" +#include "tsl/platform/statusor.h" + +namespace stream_executor { +// Maintains a cache of pointers to loaded kernels +template +static StatusOr*> LoadKernelOrGetPtr( + StreamExecutor* executor, absl::string_view kernel_name, + absl::string_view ptx, absl::Span cubin_data) { + using KernelPtrCacheKey = + std::tuple; + + static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit); + static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) = + *new absl::node_hash_map>(); + CUcontext current_context = cuda::CurrentContextOrDie(); + KernelPtrCacheKey kernel_ptr_cache_key{current_context, kernel_name, ptx}; + absl::MutexLock lock(&kernel_ptr_cache_mutex); + + auto it = kernel_ptr_cache.find(kernel_ptr_cache_key); + if (it == kernel_ptr_cache.end()) { + TF_ASSIGN_OR_RETURN(TypedKernel loaded, + (TypedKernelFactory::Create( + executor, kernel_name, ptx, cubin_data))); + it = + kernel_ptr_cache.emplace(kernel_ptr_cache_key, std::move(loaded)).first; + } + + CHECK(it != kernel_ptr_cache.end()); + return &it->second; +} + +// PTX blob for the function which checks that every byte in +// input_buffer (length is buffer_length) is equal to redzone_pattern. +// +// On mismatch, increment the counter pointed to by out_mismatch_cnt_ptr. +// +// Generated from: +// __global__ void redzone_checker(unsigned char* input_buffer, +// unsigned char redzone_pattern, +// unsigned long long buffer_length, +// int* out_mismatched_ptr) { +// unsigned long long idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) return; +// if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); +// } +// +// Code must compile for the oldest GPU XLA may be compiled for. +static const char* redzone_checker_ptx = R"( +.version 4.2 +.target sm_30 +.address_size 64 + +.visible .entry redzone_checker( + .param .u64 input_buffer, + .param .u8 redzone_pattern, + .param .u64 buffer_length, + .param .u64 out_mismatch_cnt_ptr +) +{ + .reg .pred %p<3>; + .reg .b16 %rs<3>; + .reg .b32 %r<6>; + .reg .b64 %rd<8>; + + ld.param.u64 %rd6, [buffer_length]; + mov.u32 %r1, %tid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %ntid.x; + mad.lo.s32 %r4, %r3, %r2, %r1; + cvt.u64.u32 %rd3, %r4; + setp.ge.u64 %p1, %rd3, %rd6; + @%p1 bra LBB6_3; + ld.param.u8 %rs1, [redzone_pattern]; + ld.param.u64 %rd4, [input_buffer]; + cvta.to.global.u64 %rd2, %rd4; + add.s64 %rd7, %rd2, %rd3; + ld.global.u8 %rs2, [%rd7]; + setp.eq.s16 %p2, %rs2, %rs1; + @%p2 bra LBB6_3; + ld.param.u64 %rd5, [out_mismatch_cnt_ptr]; + cvta.to.global.u64 %rd1, %rd5; + atom.global.add.u32 %r5, [%rd1], 1; +LBB6_3: + ret; +} +)"; + +StatusOr GetComparisonKernel( + StreamExecutor* executor, GpuAsmOpts gpu_asm_opts) { + absl::Span compiled_ptx = {}; + StatusOr> compiled_ptx_or = + CompileGpuAsmOrGetCached(executor->device_ordinal(), redzone_checker_ptx, + gpu_asm_opts); + if (compiled_ptx_or.ok()) { + compiled_ptx = compiled_ptx_or.value(); + } else { + static absl::once_flag ptxas_not_found_logged; + absl::call_once(ptxas_not_found_logged, [&]() { + LOG(WARNING) << compiled_ptx_or.status() + << "\nRelying on driver to perform ptx compilation. " + << "\nModify $PATH to customize ptxas location." + << "\nThis message will be only logged once."; + }); + } + + return LoadKernelOrGetPtr, uint8_t, uint64_t, + DeviceMemory>( + executor, "redzone_checker", redzone_checker_ptx, compiled_ptx); +} +} // namespace stream_executor + diff --git a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc new file mode 100644 index 00000000000000..2eba18a5b0e985 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_kernel.h" +#include "tensorflow/compiler/xla/stream_executor/kernel.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" + +namespace { +__global__ void redzone_checker_kernel(uint8_t* input_buffer, + uint8_t redzone_pattern, + uint64_t buffer_length, + uint32_t* out_mismatched_ptr) { + uint64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); +} +} // namespace + +namespace stream_executor { + +tsl::StatusOr GetComparisonKernel( + StreamExecutor* executor, GpuAsmOpts /*gpu_asm_opts*/) { + + static auto kernel = + executor->CreateTypedKernel, uint8_t, uint64_t, + DeviceMemory>( + "redzone_checker", reinterpret_cast( + redzone_checker_kernel)); + + if (!kernel.ok()) return kernel.status(); + return kernel.value().get(); +} + +} // namespace stream_executor + diff --git a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_test.cc b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_test.cc new file mode 100644 index 00000000000000..92411f7b5dd38c --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator_test.cc @@ -0,0 +1,155 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" + +#include +#include + +#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/gpu_asm_opts.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_init.h" +#include "tensorflow/compiler/xla/stream_executor/platform.h" +#include "tensorflow/compiler/xla/stream_executor/platform_manager.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor_memory_allocator.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor { +namespace gpu { +namespace { + +using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; + +static void EXPECT_REDZONE_OK(StatusOr status) { + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(status.value().ok()); +} + +static void EXPECT_REDZONE_VIOLATION( + StatusOr status) { + EXPECT_TRUE(status.ok()); + EXPECT_FALSE(status.value().ok()); +} + +TEST(RedzoneAllocatorTest, WriteToRedzone) { + constexpr int64_t kRedzoneSize = 1 << 23; // 8MiB redzone on each side + // Redzone pattern should not be equal to zero; otherwise modify_redzone will + // break. + constexpr uint8_t kRedzonePattern = 0x7e; + + // Allocate 32MiB + 1 byte (to make things misaligned) + constexpr int64_t kAllocSize = (1 << 25) + 1; + + Platform* platform = + PlatformManager::PlatformWithName(tensorflow::GpuPlatformName()).value(); + StreamExecutor* stream_exec = platform->ExecutorForDevice(0).value(); + GpuAsmOpts opts; + StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_exec->CreateStream()); + RedzoneAllocator allocator(stream.get(), &se_allocator, opts, + /*memory_limit=*/(1LL << 32), + /*redzone_size=*/kRedzoneSize, + /*redzone_pattern=*/kRedzonePattern); + TF_ASSERT_OK_AND_ASSIGN(DeviceMemory buf, + allocator.AllocateBytes(/*byte_size=*/kAllocSize)); + + EXPECT_REDZONE_OK(allocator.CheckRedzones()); + + char* buf_addr = reinterpret_cast(buf.opaque()); + DeviceMemoryBase lhs_redzone(buf_addr - kRedzoneSize, kRedzoneSize); + DeviceMemoryBase rhs_redzone(buf_addr + kAllocSize, kRedzoneSize); + + // Check that the redzones are in fact filled with kRedzonePattern. + auto check_redzone = [&](DeviceMemoryBase redzone, absl::string_view name) { + std::vector host_buf(kRedzoneSize); + TF_ASSERT_OK(stream->ThenMemcpy(host_buf.data(), redzone, kRedzoneSize)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + const int64_t kMaxMismatches = 16; + int64_t mismatches = 0; + for (int64_t i = 0; i < host_buf.size(); ++i) { + if (mismatches == kMaxMismatches) { + ADD_FAILURE() << "Hit max number of mismatches; skipping others."; + break; + } + if (host_buf[i] != kRedzonePattern) { + ++mismatches; + EXPECT_EQ(host_buf[i], kRedzonePattern) + << "at index " << i << " of " << name << " redzone"; + } + } + }; + + check_redzone(lhs_redzone, "lhs"); + check_redzone(rhs_redzone, "rhs"); + + // Modifies a redzone, checks that RedzonesAreUnmodified returns false, then + // reverts it back to its original value and checks that RedzonesAreUnmodified + // returns true. + auto modify_redzone = [&](DeviceMemoryBase redzone, int64_t offset, + absl::string_view name) { + SCOPED_TRACE(absl::StrCat(name, ", offset=", offset)); + DeviceMemoryBase redzone_at_offset( + reinterpret_cast(redzone.opaque()) + offset, 1); + char old_redzone_value = 0; + { EXPECT_REDZONE_OK(allocator.CheckRedzones()); } + TF_ASSERT_OK(stream->ThenMemcpy(&old_redzone_value, redzone_at_offset, 1)); + TF_ASSERT_OK(stream->MemZero(&redzone_at_offset, 1)); + EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones()); + + // Checking reinitializes the redzone. + EXPECT_REDZONE_OK(allocator.CheckRedzones()); + }; + + modify_redzone(lhs_redzone, /*offset=*/0, "lhs"); + modify_redzone(lhs_redzone, /*offset=*/kRedzoneSize - 1, "lhs"); + modify_redzone(rhs_redzone, /*offset=*/0, "rhs"); + modify_redzone(rhs_redzone, /*offset=*/kRedzoneSize - 1, "rhs"); +} + +// Older CUDA compute capabilities (<= 2.0) have a limitation that grid +// dimension X cannot be larger than 65535. +// +// Make sure we can launch kernels on sizes larger than that, given that the +// maximum number of threads per block is 1024. +TEST(RedzoneAllocatorTest, VeryLargeRedzone) { + // Make sure the redzone size would require grid dimension > 65535. + constexpr int64_t kRedzoneSize = 65535 * 1024 + 1; + Platform* platform = + PlatformWithName(tensorflow::GpuPlatformName()).value(); + + StreamExecutor* stream_exec = platform->ExecutorForDevice(0).value(); + GpuAsmOpts opts; + StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_exec->CreateStream()); + RedzoneAllocator allocator( + stream.get(), &se_allocator, opts, + /*memory_limit=*/ (1LL << 32), + /*redzone_size=*/kRedzoneSize, + /*redzone_pattern=*/-1); + (void)allocator.AllocateBytes(/*byte_size=*/1); + EXPECT_REDZONE_OK(allocator.CheckRedzones()); +} + +} // namespace gpu +} // namespace stream_executor + diff --git a/tensorflow/compiler/xla/stream_executor/kernel_spec.cc b/tensorflow/compiler/xla/stream_executor/kernel_spec.cc index 1fd776baf10046..1da53d9273ecf4 100644 --- a/tensorflow/compiler/xla/stream_executor/kernel_spec.cc +++ b/tensorflow/compiler/xla/stream_executor/kernel_spec.cc @@ -22,6 +22,9 @@ namespace stream_executor { KernelLoaderSpec::KernelLoaderSpec(absl::string_view kernelname) : kernelname_(std::string(kernelname)) {} +InProcessSymbol::InProcessSymbol(void *symbol, absl::string_view kernel_name) + : KernelLoaderSpec(kernel_name), symbol_(symbol) {} + OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(absl::string_view filename, absl::string_view kernelname) : KernelLoaderSpec(kernelname), filename_(std::string(filename)) {} @@ -157,6 +160,14 @@ const char *CudaPtxInMemory::original_text(int compute_capability_major, return ptx_iter->second; } +MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddInProcessSymbol( + void *symbol, absl::string_view kernel_name) { + CHECK(in_process_symbol_ == nullptr); + in_process_symbol_ = + std::make_shared(symbol, std::string(kernel_name)); + return this; +} + OpenCLTextOnDisk::OpenCLTextOnDisk(absl::string_view filename, absl::string_view kernelname) : OnDiskKernelLoaderSpec(filename, kernelname) {} diff --git a/tensorflow/compiler/xla/stream_executor/kernel_spec.h b/tensorflow/compiler/xla/stream_executor/kernel_spec.h index d4bd9486a4dfd1..f5b600b8955d75 100644 --- a/tensorflow/compiler/xla/stream_executor/kernel_spec.h +++ b/tensorflow/compiler/xla/stream_executor/kernel_spec.h @@ -86,6 +86,18 @@ class KernelLoaderSpec { SE_DISALLOW_COPY_AND_ASSIGN(KernelLoaderSpec); }; +// Loads kernel from in process symbol pointer (e.g. pointer to C++ device +// function). +class InProcessSymbol : public KernelLoaderSpec { + public: + InProcessSymbol(void *symbol, absl::string_view kernel_name); + + void *symbol() const { return symbol_; } + + private: + void *symbol_; +}; + // An abstract kernel loader spec that has an associated file path, where // there's a canonical suffix for the filename; e.g. see CudaPtxOnDisk whose // canonical filename suffix is ".ptx". @@ -279,6 +291,7 @@ class MultiKernelLoaderSpec { // Convenience getters for testing whether these platform variants have // kernel loader specifications available. + bool has_in_process_symbol() const { return in_process_symbol_ != nullptr; } bool has_cuda_ptx_on_disk() const { return cuda_ptx_on_disk_ != nullptr; } bool has_cuda_cubin_on_disk() const { return cuda_cubin_on_disk_ != nullptr; } bool has_cuda_cubin_in_memory() const { @@ -291,6 +304,10 @@ class MultiKernelLoaderSpec { // Accessors for platform variant kernel load specifications. // Precondition: corresponding has_* is true. + const InProcessSymbol &in_process_symbol() const { + CHECK(has_in_process_symbol()); + return *in_process_symbol_; + } const CudaPtxOnDisk &cuda_ptx_on_disk() const { CHECK(has_cuda_ptx_on_disk()); return *cuda_ptx_on_disk_; @@ -328,6 +345,8 @@ class MultiKernelLoaderSpec { // the PTX or OpenCL being loaded. Also be aware that in CUDA C++ the kernel // name may be mangled by the compiler if it is not declared in an // extern "C" scope. + MultiKernelLoaderSpec *AddInProcessSymbol(void *symbol, + absl::string_view kernel_name); MultiKernelLoaderSpec *AddOpenCLTextOnDisk(absl::string_view filename, absl::string_view kernelname); MultiKernelLoaderSpec *AddOpenCLBinaryOnDisk(absl::string_view filename, @@ -352,6 +371,8 @@ class MultiKernelLoaderSpec { absl::string_view kernelname); private: + std::shared_ptr + in_process_symbol_; // In process symbol pointer. std::unique_ptr cuda_ptx_on_disk_; // PTX text that resides in a file. std::unique_ptr diff --git a/tensorflow/compiler/xla/stream_executor/rocm/BUILD b/tensorflow/compiler/xla/stream_executor/rocm/BUILD index 68351dbcbda430..7386fcd81f05fd 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/BUILD +++ b/tensorflow/compiler/xla/stream_executor/rocm/BUILD @@ -7,7 +7,9 @@ load( "stream_executor_friends", ) load("//tensorflow/tsl:tsl.bzl", "set_external_visibility", "tsl_copts") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", "rocm_copts") +load("@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", "rocm_copts", "if_rocm_hipblaslt" +) load("//tensorflow/tsl/platform:build_config_root.bzl", "if_static") package( @@ -184,6 +186,7 @@ cc_library( hdrs = if_rocm_is_configured(["rocm_blas.h"]), visibility = ["//visibility:public"], deps = if_rocm_is_configured([ + ":hipblas_lt_header", ":rocblas_if_static", ":rocblas_wrapper", ":rocm_gpu_executor", @@ -200,6 +203,7 @@ cc_library( "//tensorflow/compiler/xla/stream_executor/gpu:gpu_helpers_header", "//tensorflow/compiler/xla/stream_executor/gpu:gpu_stream_header", "//tensorflow/compiler/xla/stream_executor/gpu:gpu_timer_header", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt", "//tensorflow/compiler/xla/stream_executor/platform", "//tensorflow/compiler/xla/stream_executor:blas", "//tensorflow/compiler/xla/stream_executor/platform:dso_loader", @@ -431,6 +435,77 @@ cc_library( alwayslink = True, ) +cc_library( + name = "hipblas_lt_kernel", + hdrs = if_rocm_is_configured([ + "hip_blas_lt.h", + ]), + srcs = if_rocm_is_configured(["hip_blas_lt.cu.cc"]), + copts = rocm_copts(), + deps = if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + #MATH_DEPS + ]) +) + +cc_library( + name = "hipblaslt_if_static", + deps = if_rocm_hipblaslt([ + "@local_config_rocm//rocm:hipblaslt", + ]), +) + +cc_library( + name = "amdhipblaslt_plugin", + srcs = if_rocm_is_configured([ + "hip_blas_lt.cc", + "hip_blas_utils.cc", + ]), + hdrs = if_rocm_is_configured([ + "hip_blas_lt.h", + "hipblaslt_wrapper.h", + "hip_blas_utils.h", + "rocm_blas.h", + ]), + deps = if_rocm_is_configured([ + # keep sorted + ":hipblas_lt_header", + ":hipblas_lt_kernel", + ":rocm_platform_id", + "@com_google_absl//absl/types:any", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_activation", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_helpers_header", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_stream_header", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_timer_header", + "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/compiler/xla/stream_executor/platform", + "//tensorflow/compiler/xla/stream_executor/platform:dso_loader", + "//tensorflow/compiler/xla/stream_executor:scratch_allocator", + "@local_config_rocm//rocm:rocm_headers", + ]) + if_static([ + ":hipblaslt_if_static", + ]), + linkopts = ["-lhipblaslt"], + alwayslink = True, +) + +cc_library( + name = "hipblas_lt_header", + hdrs = if_rocm_is_configured([ + "hip_blas_lt.h", + "hipblaslt_wrapper.h", + "hip_blas_utils.h", + ]), + visibility = ["//visibility:public"], + deps = if_rocm_is_configured([ + # keep sorted + ]), +) + cc_library( name = "roctracer_if_static", deps = if_static([ @@ -483,6 +558,7 @@ cc_library( ":rocm_driver", ":rocm_platform", ":rocm_helpers", + ":amdhipblaslt_plugin", ]), alwayslink = 1, ) diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cc b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cc new file mode 100644 index 00000000000000..8571f8aea156c1 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cc @@ -0,0 +1,736 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#define LEGACY_HIPBLAS_DIRECT +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rocm/rocm_config.h" +#include "rocm/include/rocblas/rocblas.h" +#include "rocm/include/hipblaslt/hipblaslt-ext.hpp" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/tsl/platform/bfloat16.h" +//#include "tensorflow/core/lib/core/errors.h" + +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_activation.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_helpers.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h" +#include "tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h" +#include "tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.h" +#include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" + +#define SET_ATTR(setter, handle, attr, value) \ + ToStatus(setter(handle, attr, &value, sizeof(decltype(value))), #setter) + +// hipblasLtMatmulDescGetAttribute does not allow nullptr for the last +// argument (size_t* sizeWritten) +#define GET_ATTR(getter, handle, attr, ValueT) \ + [&]() -> xla::StatusOr { \ + ValueT value; \ + size_t size; \ + TF_RETURN_IF_ERROR(ToStatus( \ + getter(handle, attr, &value, sizeof(ValueT), &size), #getter)); \ + return std::move(value); \ + }() + +namespace stream_executor { + +namespace rocm { + +using ::xla::complex128; +using ::xla::complex64; +using tsl::bfloat16; +using namespace hipblaslt_ext; + +// void GroupGemmUpdateArgs(hipStream_t stream, +// UserArguments *dev_args, +// const gpu::GroupedGemmConfig& cfg); + +void GroupGemmUpdateArgs(hipStream_t stream, + UserArguments *dev_args, + const void **a, const void **b, const void **c, void **d, + uint32_t num_gemms); + +namespace { + +typedef struct __attribute__((packed, aligned(8))) _rocblaslt_matmul_algo { + uint8_t data[8] = {0}; + bool fallback = false; + size_t max_workspace_bytes = 0; +} rocblaslt_matmul_algo; + +static_assert(sizeof(hipblasLtMatmulAlgo_t) == sizeof(rocblaslt_matmul_algo), + "Structure size does not match!"); + +template +xla::Status SetAttr(hipblasLtMatrixLayout_t handle, + hipblasLtMatrixLayoutAttribute_t attr, T value) { + return SET_ATTR(hipblasLtMatrixLayoutSetAttribute, handle, attr, value); +} + +template +xla::StatusOr GetAttr(hipblasLtMatrixLayout_t handle, + hipblasLtMatrixLayoutAttribute_t attr) { + return GET_ATTR(hipblasLtMatrixLayoutGetAttribute, handle, attr, T); +} + +template +xla::Status SetAttr(hipblasLtMatmulDesc_t handle, + hipblasLtMatmulDescAttributes_t attr, T value) { + return SET_ATTR(hipblasLtMatmulDescSetAttribute, handle, attr, value); +} + +template +xla::StatusOr GetAttr(hipblasLtMatmulDesc_t handle, + hipblasLtMatmulDescAttributes_t attr) { + return GET_ATTR(hipblasLtMatmulDescGetAttribute, handle, attr, T); +} + +template +xla::Status SetAttr(hipblasLtMatmulPreference_t handle, + hipblasLtMatmulPreferenceAttributes_t attr, T value) { + return SET_ATTR(hipblasLtMatmulPreferenceSetAttribute, handle, attr, + value); +} + +xla::StatusOr AsHipblasLtEpilogue( + gpu::BlasLt::Epilogue epilogue) { + switch (epilogue) { + case gpu::BlasLt::Epilogue::kDefault: + return HIPBLASLT_EPILOGUE_DEFAULT; + case gpu::BlasLt::Epilogue::kReLU: + return HIPBLASLT_EPILOGUE_RELU; + case gpu::BlasLt::Epilogue::kBias: + return HIPBLASLT_EPILOGUE_BIAS; + case gpu::BlasLt::Epilogue::kBiasThenReLU: + return HIPBLASLT_EPILOGUE_RELU_BIAS; + case gpu::BlasLt::Epilogue::kGELU: + return HIPBLASLT_EPILOGUE_GELU; +#if TF_ROCM_VERSION >= 60000 + case gpu::BlasLt::Epilogue::kGELUWithAux: + return HIPBLASLT_EPILOGUE_GELU_AUX; + case gpu::BlasLt::Epilogue::kBiasThenGELU: + return HIPBLASLT_EPILOGUE_GELU_BIAS; + case gpu::BlasLt::Epilogue::kBiasThenGELUWithAux: + return HIPBLASLT_EPILOGUE_GELU_AUX_BIAS; +#endif + default: + return xla::InternalError("Unsupported epilogue: %d", + static_cast(epilogue)); + } +} + +} // namespace + +BlasLt::BlasLt(gpu::GpuExecutor* parent) + : parent_(parent), blas_lt_(nullptr, hipblasLtDestroy) {} + +xla::Status BlasLt::Init() { + hipblasLtHandle_t blas_lt; + SE_HIPBLAS_RETURN_IF_ERROR(hipblasLtCreate(&blas_lt)); + absl::MutexLock lock(&mu_); + blas_lt_.reset(blas_lt); + return xla::OkStatus(); +} + +/*static*/ xla::StatusOr BlasLt::MatrixLayout::Create( + const gpu::MatrixLayout& m) { + + auto hipblas_data_type_ = AsHipblasDataType(m.dtype); + hipblasLtMatrixLayout_t hip_layout; + SE_HIPBLAS_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate( + &hip_layout, hipblas_data_type_, m.num_rows, m.num_cols, + m.leading_dim_stride)); + // Wrap hipblas handle immediately, so it is cleaned up if an error occurs. + BlasLt::MatrixLayout layout(hip_layout, hipblas_data_type_); + + if (m.order != gpu::MatrixLayout::Order::kColumnMajor){ + return xla::InternalError("HipblasLT does not support row-major matrices"); + } + TF_RETURN_IF_ERROR(SetAttr(hip_layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + static_cast(m.batch_size))); + + VLOG(2) << "BlasLt::MatrixLayout::Create type: " << (int)m.dtype + << " rows: " << m.num_rows << " cols: " << m.num_cols + << " batch_size: " << m.batch_size + << " leading_dim_stride: " << m.leading_dim_stride + << " batch_stride: " << m.batch_stride; + + TF_RETURN_IF_ERROR(SetAttr(hip_layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + m.batch_stride)); + return std::move(layout); +} + +/*static*/ xla::StatusOr BlasLt::MatmulDesc::Create( + blas::ComputationType compute_type, blas::DataType scale_type, + blas::Transpose trans_a, blas::Transpose trans_b, Epilogue epilogue, + PointerMode pointer_mode) { + hipblasLtMatmulDesc_t hip_desc; + VLOG(2) << "BlasLt::MatmulDesc::Create compute_type: " << int(compute_type) + << " scale_type: " << int(scale_type) + << " epilogue: " << int(epilogue) << " trans_a: " << int(trans_a) + << " trans_b: " << int(trans_b) << " pointer_mode " + << int(pointer_mode); + auto hip_scale_type = AsHipblasDataType(scale_type); + auto hip_compute_type = AsHipblasComputeType(compute_type); + SE_HIPBLAS_RETURN_IF_ERROR(hipblasLtMatmulDescCreate( + &hip_desc, hip_compute_type, hip_scale_type)); + + int32_t bias_flag = + static_cast(epilogue) & static_cast(Epilogue::kBias); + // Wrap hipblas handle immediately, so it is cleaned up if an error occurs. + BlasLt::MatmulDesc desc(hip_desc, hip_compute_type, hip_scale_type, + bias_flag != 0); + if (pointer_mode != PointerMode::kHost) { + return xla::InternalError("hipblaslt does not support device pointers"); + } + + TF_RETURN_IF_ERROR(SetAttr(hip_desc, HIPBLASLT_MATMUL_DESC_TRANSA, + AsHipblasOperation(trans_a))); + TF_RETURN_IF_ERROR(SetAttr(hip_desc, HIPBLASLT_MATMUL_DESC_TRANSB, + AsHipblasOperation(trans_b))); + TF_ASSIGN_OR_RETURN(hipblasLtEpilogue_t epi, AsHipblasLtEpilogue(epilogue)); + TF_RETURN_IF_ERROR(SetAttr(hip_desc, HIPBLASLT_MATMUL_DESC_EPILOGUE, epi)); + return std::move(desc); +} + +auto BlasLt::MatmulPlan::GetAlgorithms(size_t max_algorithm_count, + size_t max_workspace_size) const + -> xla::StatusOr> { + max_algorithm_count = std::min(max_algorithm_count, size_t{INT_MAX}); + std::vector results(max_algorithm_count); + + { + absl::MutexLock lock(&blas_lt_ref_.mu_); + TF_RET_CHECK(blas_lt_ref_.blas_lt_ != nullptr); + + hipblasLtMatmulPreference_t hip_preference; + SE_HIPBLAS_RETURN_IF_ERROR( + hipblasLtMatmulPreferenceCreate(&hip_preference)); + + // Wrap hipblas handle immediately, so it is cleaned up if an error occurs. + Owned preference( + hip_preference, hipblasLtMatmulPreferenceDestroy); + + TF_RETURN_IF_ERROR(SetAttr( + hip_preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + max_workspace_size)); + + gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_}; + + // hipBlasLt requires setting the bias pointer (even a dummy one), otherwise + // no algorithms can be found for "bias epilogues". This is to be removed + // later when this limitation is gone. + if (op_desc_.has_bias_epilogue()) { + static int64_t dummyPointer = 0xACEBALL; + TF_RETURN_IF_ERROR(SetAttr( + op_desc_.get(), HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &dummyPointer)); + } + + int found_algorithm_count = 0; + auto error = hipblasLtMatmulAlgoGetHeuristic( + blas_lt_ref_.blas_lt_.get(), op_desc_.get(), a_desc_.get(), + b_desc_.get(), c_desc_.get(), d_desc_.get(), preference.get(), + max_algorithm_count, results.data(), &found_algorithm_count); + if (error != 0) { + VLOG(0) << "hipblasLtMatmulAlgoGetHeuristic returned " << (int)error; + SE_HIPBLAS_RETURN_IF_ERROR(error); + } + results.resize(found_algorithm_count); + } // end mutex block + + std::vector algorithms; + algorithms.reserve(max_algorithm_count); + for (const hipblasLtMatmulHeuristicResult_t& result : results) { + if (result.state == HIPBLAS_STATUS_SUCCESS) { // Skip failed algos. + auto roc_algo = (const rocblaslt_matmul_algo*)&result.algo; + auto pindex = (int *)roc_algo->data; + algorithms.push_back({result.algo, result.workspaceSize, + static_cast< blas::AlgorithmType >(*pindex)}); + if (algorithms.size() >= max_algorithm_count) break; + } + } + return std::move(algorithms); +} + +xla::Status BlasLt::MatmulPlan::SetAlgorithm(const MatmulAlgorithm& algorithm) { + algorithm_ = algorithm; + return xla::OkStatus(); +} + +auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg) const + -> xla::StatusOr { + auto lhs_layout = cfg.lhs_layout, rhs_layout = cfg.rhs_layout, + output_layout = cfg.output_layout, c_layout = cfg.c_layout; + + // cublasLt matmul requires batch sizes to be equal. If only one operand has a + // batch, the other will be broadcast (as its batch_stride == 0). + size_t batch_size = std::max(lhs_layout.batch_size, rhs_layout.batch_size); + lhs_layout.batch_size = batch_size; + rhs_layout.batch_size = batch_size; + + bool must_swap_operands = + MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout, &c_layout); + + // Do not transpose either input. Note the cuBLASLt documentation somewhat + // incorrectly claims "A must be transposed and B non-transposed" when A and B + // are FP8 (https://docs.nvidia.com/cuda/cublas/#cublasltmatmul). In reality, + // this is only true if A and B are column-major. If A is row-major, A must + // *not* be transposed, and if B is row-major, B must be transposed. We never + // transpose A or B, and expect the caller to ensure A is row-major and B is + // column when A and B are FP8. + auto trans_a = lhs_layout.transpose, trans_b = rhs_layout.transpose; + + // if (xla::primitive_util::IsF8Type(lhs_layout.dtype) && + // lhs_layout.order == gpu::MatrixLayout::Order::kColumnMajor) { + // return xla::Internal("The F8 LHS must be column-major"); + // } + // if (xla::primitive_util::IsF8Type(rhs_layout.dtype) && + // rhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) { + // return xla::Internal("The F8 RHS must be row-major"); + // } + + TF_ASSIGN_OR_RETURN(auto compute_type, gpu::GetBlasComputationType( + lhs_layout.dtype, output_layout.dtype, cfg.compute_precision)); + + if (lhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) { + trans_a = blas::Transpose::kTranspose; + lhs_layout.Transpose(); + } + if (rhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) { + trans_b = blas::Transpose::kTranspose; + rhs_layout.Transpose(); + } + + TF_ASSIGN_OR_RETURN( + auto op_desc, + MatmulDesc::Create(compute_type, + gpu::GetScaleType(output_layout.dtype, compute_type), + trans_a, trans_b, cfg.epilogue)); + + TF_ASSIGN_OR_RETURN(auto a_desc, MatrixLayout::Create(lhs_layout)); + TF_ASSIGN_OR_RETURN(auto b_desc, MatrixLayout::Create(rhs_layout)); + TF_ASSIGN_OR_RETURN(auto c_desc, MatrixLayout::Create(c_layout)); + TF_ASSIGN_OR_RETURN(auto d_desc, MatrixLayout::Create(output_layout)); + + // std::make_unique won't work with brace initialization in C++17 ;( + auto M = std::make_unique(*this, std::move(op_desc), + std::move(a_desc), std::move(b_desc), + std::move(c_desc), std::move(d_desc), + cfg.alpha, cfg.beta, must_swap_operands); + return xla::StatusOr{std::move(M)}; +} + +xla::Status BlasLt::MatmulPlan::DoMatmul( + Stream* stream, const void* alpha, DeviceMemoryBase a, DeviceMemoryBase b, + const void* beta, DeviceMemoryBase c, DeviceMemoryBase d, + DeviceMemoryBase bias, + DeviceMemoryBase aux, DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, + DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, DeviceMemoryBase d_amax, + absl::optional workspace, + absl::optional allocator, + blas::ProfileResult* profile_result) const { + + std::unique_ptr timer; + if (profile_result != nullptr) { + timer.reset(new gpu::GpuTimer(blas_lt_ref_.parent_)); + if (!timer->Init() || !timer->Start(gpu::AsGpuStream(stream))) { + return xla::InternalError("Unable to start gpu timer"); + } + } + + if(!algorithm_.has_value()) return xla::InternalError("Algorithm is not set!"); + + void* workspace_addr = nullptr; + uint64_t workspace_size = 0; + if (workspace.has_value()) { + workspace_addr = workspace.value().opaque(); + workspace_size = workspace.value().size(); + TF_RET_CHECK(workspace_size >= algorithm_->workspace_size); + } else if (algorithm_->workspace_size > 0) { + + if (!allocator || allocator.value() == nullptr) { + return xla::InternalError("Allocator is not set: skipping solution!"); + } + + TF_ASSIGN_OR_RETURN(auto alloc, + allocator.value()->AllocateBytes(algorithm_->workspace_size)); + + workspace_addr = gpu::GpuMemoryMutable(&alloc); + workspace_size = algorithm_->workspace_size; + } + + auto palgo = absl::any_cast(&algorithm_->opaque_algo); + { + absl::MutexLock lock(&blas_lt_ref_.mu_); + TF_RET_CHECK(blas_lt_ref_.blas_lt_ != nullptr); + // We must set the bias and aux pointers while holding the mutex, to avoid a + // potential race condition from multiple threads sharing the same plan. + if (op_desc_.has_bias_epilogue() && bias != nullptr) { + TF_RETURN_IF_ERROR(SetAttr( + op_desc_.get(), HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias.opaque())); + } + + if (a_scale != nullptr) { + TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), + HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, + a_scale.opaque())); + } + if (b_scale != nullptr) { + TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), + HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, + b_scale.opaque())); + } + if (c_scale != nullptr || d_scale != nullptr) { + return xla::InternalError( + "hipblaslt does not support c_scale or d_scale."); + } + + if (d_amax != nullptr) { + return xla::InternalError("hipblaslt does not support amax"); + } + + if (aux != nullptr) { + return xla::InternalError( + "hipblaslt does not support auxiliary inputs / outputs"); + } + + gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_}; + + if (palgo != nullptr) { + SE_HIPBLAS_RETURN_IF_ERROR(hipblasLtMatmul( + blas_lt_ref_.blas_lt_.get(), op_desc_.get(), alpha, a.opaque(), + a_desc_.get(), b.opaque(), b_desc_.get(), beta, c.opaque(), + c_desc_.get(), d.opaque(), d_desc_.get(), palgo, workspace_addr, + workspace_size, gpu::AsGpuStreamValue(stream))); + } else { + return xla::InternalError("hipblaslt: Invalid algorithm type"); + } + } + + if (profile_result != nullptr) { + if (!timer->Stop(gpu::AsGpuStream(stream))) { + return xla::InternalError("Unable to stop gpu timer"); + } + profile_result->set_algorithm(algorithm_->id); + profile_result->set_is_valid(true); + profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds()); + } + return xla::OkStatus(); +} + +xla::Status BlasLt::MatmulPlan::ExecuteOnStream( + Stream* stream, DeviceMemoryBase a, DeviceMemoryBase b, DeviceMemoryBase c, + DeviceMemoryBase d, DeviceMemoryBase bias, DeviceMemoryBase aux, + DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, + DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, DeviceMemoryBase d_amax, + absl::optional workspace, + absl::optional scratch_allocator, + blas::ProfileResult* profile_result) const { + if (must_swap_operands_) { + std::swap(a, b); + } + + auto operand_types = std::make_tuple( + a_desc_.type(), b_desc_.type(), c_desc_.type(), d_desc_.type()); + +#define TYPED_MATMUL(SCALENTYPE, ATYPE, BTYPE, CTYPE, DTYPE) \ + if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE, DTYPE)) { \ + return gpu::BlasLt::MatmulPlan::DoMatmul< SCALENTYPE >( \ + stream, alpha_, a, b, beta_, c, d, bias, aux, a_scale, b_scale, \ + c_scale, d_scale, d_amax, workspace, scratch_allocator, \ + profile_result); \ + } + // Other data types: + TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF) + TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F) + TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_32F, HIP_R_32F) + TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_32F, HIP_R_32F) + TYPED_MATMUL(float, HIP_R_32F, HIP_R_32F, HIP_R_32F, HIP_R_32F) + TYPED_MATMUL(double, HIP_R_64F, HIP_R_64F, HIP_R_64F, HIP_R_64F) + TYPED_MATMUL(complex64, HIP_C_32F, HIP_C_32F, HIP_C_32F, HIP_C_32F) + TYPED_MATMUL(complex128, HIP_C_64F, HIP_C_64F, HIP_C_64F, HIP_C_64F) + +#undef TYPED_MATMUL + + return xla::InternalError("Unexpected dtype"); +} + + +BlasLt::GroupedMatmulPlan::GroupedMatmulPlan(const BlasLt& blas_lt) : + blas_lt_ref_(blas_lt) {} + +BlasLt::GroupedMatmulPlan::~GroupedMatmulPlan() { + if(host_args_ != nullptr) { + blas_lt_ref_.parent_->HostMemoryDeallocate(host_args_); + } + if(!device_args_.is_null()) { + blas_lt_ref_.parent_->Deallocate(&device_args_); + } +} + +auto BlasLt::GetGroupedMatmulPlan(Stream *stream, + const gpu::GroupedGemmConfig& cfg) const + -> xla::StatusOr { + + auto plan = std::make_unique< GroupedMatmulPlan >(*this); + + plan->grouped_gemm_ = std::make_unique< GroupedGemm >(blas_lt_.get(), + AsHipblasOperation(cfg.trans_a), + AsHipblasOperation(cfg.trans_b), + AsHipblasDataType(cfg.type_a), + AsHipblasDataType(cfg.type_b), + AsHipblasDataType(cfg.type_c), + AsHipblasDataType(cfg.type_d), + AsHipblasComputeType(cfg.compute_type)); + auto& ggemm = plan->grouped_gemm_; + + std::vector< int64_t > m(cfg.batch_count, cfg.m), + n(cfg.batch_count, cfg.n), + k(cfg.batch_count, cfg.k), + batch_count(cfg.batch_count, 1), + lda(cfg.batch_count, cfg.lda), + ldb(cfg.batch_count, cfg.ldb), + ldc(cfg.batch_count, cfg.ldc), + ldd(cfg.batch_count, cfg.ldd), + strideA(cfg.batch_count, cfg.m * cfg.k), + strideB(cfg.batch_count, cfg.n * cfg.k), + strideC(cfg.batch_count, cfg.m * cfg.n), + strideD(cfg.batch_count, cfg.m * cfg.n); + + std::vector< GemmEpilogue > epilogue(cfg.batch_count, + GemmEpilogue{}); + std::vector< GemmInputs > inputs(cfg.batch_count); + for(int64_t i = 0; i < cfg.batch_count; i++) { + inputs[i].a = const_cast< void * >(cfg.a[i]); + inputs[i].b = const_cast< void * >(cfg.b[i]); + inputs[i].c = const_cast< void * >(cfg.c[i]); + inputs[i].d = cfg.d[i]; + inputs[i].alpha = const_cast< void * >(cfg.alpha); + inputs[i].beta = const_cast< void * >(cfg.beta); + } + + GemmProblemType problem = { + .op_a = AsHipblasOperation(cfg.trans_a), + .op_b = AsHipblasOperation(cfg.trans_b), + .type_a = AsHipblasDataType(cfg.type_a), + .type_b = AsHipblasDataType(cfg.type_b), + .type_c = AsHipblasDataType(cfg.type_c), + .type_d = AsHipblasDataType(cfg.type_d), + .type_compute = AsHipblasComputeType(cfg.compute_type) + }; + + uint64 mem_size = cfg.batch_count * sizeof(UserArguments); + { + absl::MutexLock lock(&mu_); + SE_HIPBLAS_RETURN_IF_ERROR(ggemm->setProblem(m, n, k, batch_count, + lda, ldb, ldc, ldd, strideA, strideB, strideC, strideD, + epilogue, inputs, problem)); + + plan->host_args_ = static_cast< UserArguments *>( + parent_->HostMemoryAllocate(mem_size)); + if(plan->host_args_ == nullptr) { + return xla::InternalError("Unable to allocate host memory for user args!"); + } + SE_HIPBLAS_RETURN_IF_ERROR(ggemm-> + getDefaultValueForDeviceUserArguments(plan->host_args_)); + + // NOTE: memory must be aligned by 16 bytes ?? + auto raw_mem = parent_->Allocate(mem_size, /* memory_space */ 0); + // TF_ASSIGN_OR_RETURN(auto dev_mem, allocator->Allocate(parent_->device_ordinal(), + // mem_size))); + if(raw_mem == nullptr) { + return xla::InternalError("Unable to allocate memory for grouped gemm params!"); + } + plan->device_args_ = GroupedMatmulPlan::DeviceMemoryArgs(raw_mem.opaque(), mem_size); + + if(!stream->ThenMemcpy(&plan->device_args_, plan->host_args_, mem_size).ok()) { + return xla::InternalError("Memcpy failed!"); + } + + } // end block + + //for(const auto& a : plan->host_args_) + { + // const auto& a = plan->host_args_[0]; + + // std::ostringstream os; + // for(int i = 0; i < sizeof(a.alpha); i++) { + // os << std::hex << (uint32_t)a.alpha[i]; + // } + // VLOG(0) << a.m << "," << a.n << "," << a.batch << "," << a.k << + // " alpha " << os.str() << + // " pointers: " << a.d << "," << a.c << "," << a.a << "," << a.b << + // " strides: " << a.strideD1 << "," << a.strideD2 << "," << a.strideA1 << "," << a.strideA2 << + // " activate: " << a.activationType; + } + return xla::StatusOr(std::move(plan)); +} + +auto BlasLt::GroupedMatmulPlan::GetAlgorithms( + size_t max_algorithm_count, size_t max_workspace_size) -> + xla::StatusOr> { + +// gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_}; ?? + + // GemmPreference gemmPref; + // gemmPref.setMaxWorkspaceBytes(max_workspace_size); + + std::vector heuristicResult; + std::vector algorithms; + + gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_}; + absl::MutexLock lock(&blas_lt_ref_.mu_); + + auto problem = grouped_gemm_->getProblemTypes()[0]; + //VLOG(0) << problem.op_a <<","<< + // problem.op_b<<","<< + // problem.type_a<<","<< + // problem.type_b<<","<< + // problem.type_c<<","<< + // problem.type_d<<","<< + // problem.type_compute; + + + // HIPBLAS_OP_N = 111, /**< Operate with the matrix. */ + // HIPBLAS_OP_T = 112, /**< Operate with the transpose of the matrix. */ + // HIPBLAS_OP_C = 113 /**< Operate with the conjugate transpose of the matrix. */ + + + SE_HIPBLAS_RETURN_IF_ERROR(getAllAlgos(blas_lt_ref_.blas_lt_.get(), + GemmType::HIPBLASLT_GROUPED_GEMM, + problem.op_a, + problem.op_b, + problem.type_a, + problem.type_b, + problem.type_c, + problem.type_d, + problem.type_compute, + heuristicResult)); + // SE_HIPBLAS_RETURN_IF_ERROR( + // grouped_gemm_->algoGetHeuristic(max_algorithm_count, gemmPref, + // heuristicResult)); + VLOG(2) << "Total heuristics found: " << heuristicResult.size(); + algorithms.reserve(max_algorithm_count); + for(auto& res : heuristicResult) { + size_t workspace_size = 0; + if(grouped_gemm_->isAlgoSupported(res.algo, workspace_size)) { + auto roc_algo = (const rocblaslt_matmul_algo*)&res.algo; + auto pindex = (int *)roc_algo->data; + algorithms.push_back({res.algo, workspace_size, + static_cast< blas::AlgorithmType >(*pindex)}); + if (algorithms.size() >= max_algorithm_count) break; + } + } + return algorithms; +} + +xla::Status BlasLt::GroupedMatmulPlan::SetAlgorithm( + const MatmulAlgorithm& algorithm, + ScratchAllocator * allocator) +{ + auto palgo = absl::any_cast(&algorithm.opaque_algo); + if(palgo == nullptr) { + return xla::InternalError("Wrong algorithm instance !"); + } + algorithm_ = algorithm; + void* workspace_addr = nullptr; + uint64_t workspace_size = algorithm_->workspace_size; + + if (workspace_size > 0) { + if (allocator == nullptr) { + return xla::InternalError("This algorithm requires a non-zero workspace!"); + + } + TF_ASSIGN_OR_RETURN(auto alloc, allocator->AllocateBytes(workspace_size)); + workspace_addr = gpu::GpuMemoryMutable(&alloc); + } + gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_}; + + absl::MutexLock lock(&blas_lt_ref_.mu_); + + // NOTE NOTE: it could be that workspace is no longer valid after + // this function returns !!!! + + + SE_HIPBLAS_RETURN_IF_ERROR(grouped_gemm_->initialize( + *palgo, workspace_addr)); + return xla::OkStatus(); +} + +xla::Status BlasLt::GroupedMatmulPlan::ExecuteOnStream(Stream *stream, + const gpu::GroupedGemmConfig& cfg, + blas::ProfileResult* profile_result) { + + if((size_t)cfg.batch_count * sizeof(UserArguments) != device_args_.size() || + !algorithm_.has_value()) + { + return xla::InternalError("GroupedGemm config mismatch or algorithm is unset!"); + } + + std::unique_ptr timer; + if (profile_result != nullptr) { + timer.reset(new gpu::GpuTimer(blas_lt_ref_.parent_)); + if (!timer->Init() || !timer->Start(gpu::AsGpuStream(stream))) { + return xla::InternalError("Unable to start gpu timer"); + } + } + + // NOTE: we can also use GPU kernel to update pointers directly + // in device mem => then memcpy won't be necessary + //for(size_t i = 0; i < device_args_.size(); i++) { + // host_args_[i].a = const_cast< void * >(cfg.a[i]); + // host_args_[i].b = const_cast< void * >(cfg.b[i]); + // host_args_[i].c = const_cast< void * >(cfg.c[i]); + // host_args_[i].d = const_cast< void * >(cfg.d[i]); + //} + + // gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_}; ?? + + GroupGemmUpdateArgs(gpu::AsGpuStreamValue(stream), + static_cast(device_args_.opaque()), + cfg.a, cfg.b, cfg.c, cfg.d, + cfg.batch_count); + + gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_}; + { + + absl::MutexLock lock(&blas_lt_ref_.mu_); + SE_HIPBLAS_RETURN_IF_ERROR(grouped_gemm_->run( + device_args_.opaque(), gpu::AsGpuStreamValue(stream))); + } // end block + + if (profile_result != nullptr) { + if (!timer->Stop(gpu::AsGpuStream(stream))) { + return xla::InternalError("Unable to stop gpu timer"); + } + // algorithm_ is alrady verified for correctness ! + profile_result->set_algorithm(algorithm_->id); + profile_result->set_is_valid(true); + profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds()); + } + return xla::OkStatus(); +} + +} // namespace rocm + +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cu.cc b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cu.cc new file mode 100644 index 00000000000000..b21e116819b6d5 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cu.cc @@ -0,0 +1,58 @@ +/* Copyright 2023 The OpenXLA Authors. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ==============================================================================*/ +#define LEGACY_HIPBLAS_DIRECT +#include +#include "rocm/rocm_config.h" +#include "rocm/include/hipblaslt/hipblaslt-ext.hpp" +#include "rocm/include/hipblaslt/hipblaslt.h" + +namespace stream_executor { +namespace rocm { + +using namespace hipblaslt_ext; + +namespace { + +__global__ void CopyUserArgsKernel(UserArguments *dest_args, + const void **a, const void **b, const void **c, void **d, + uint32_t num_gemms) +{ + uint32_t idx = blockIdx.x*blockDim.x + threadIdx.x; + if(idx < num_gemms) { + // writing ArrayOfStructs is not optimal.. + auto arg = dest_args[idx]; + arg.a = const_cast< void *>(a[idx]); + arg.b = const_cast< void *>(b[idx]); + arg.c = const_cast< void *>(c[idx]); + arg.d = d[idx]; + //printf("idx: %d %p %p %p %p\n", idx, arg.a, arg.b, arg.c, arg.d); + } +} +} // namespace + +void GroupGemmUpdateArgs(hipStream_t stream, + UserArguments *dev_args, + //const gpu::GroupedGemmConfig& cfg + const void **a, const void **b, const void **c, void **d, + uint32_t num_gemms) { + + const uint32_t block_sz = 128, + n_blocks = (num_gemms + block_sz - 1)/block_sz; + hipLaunchKernelGGL(CopyUserArgsKernel, n_blocks, + std::min(block_sz, num_gemms), 0, + stream, + dev_args, + //static_cast< UserArguments *>(device_args_.opaque()), + a, b, c, d, num_gemms); +} +} // namespace rocm +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h new file mode 100644 index 00000000000000..3d23dc4d93b8e1 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h @@ -0,0 +1,202 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIP_BLAS_LT_H_ +#define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIP_BLAS_LT_H_ + +#include "rocm/rocm_config.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" +//#include "tensorflow/compiler/xla/stream_executor/host_or_device_scalar.h" +//#include "tensorflow/compiler/xla/types.h" + +#include "tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h" + +namespace hipblaslt_ext { + class GroupedGemm; + struct UserArguments; +} + +namespace stream_executor { + +namespace gpu { +class GpuExecutor; +} // namespace gpu + +namespace rocm { + +class BlasLt : public gpu::BlasLt { + template + using Owned = + std::unique_ptr, hipblasStatus_t (*)(T)>; + + public: + struct MatrixLayout { + static xla::StatusOr Create(const gpu::MatrixLayout& m); + + hipDataType type() const { return datatype_; } + hipblasLtMatrixLayout_t get() const { return handle_.get(); } + + private: + MatrixLayout(hipblasLtMatrixLayout_t handle, hipDataType datatype) + : handle_(handle, hipblasLtMatrixLayoutDestroy), + datatype_(datatype) {} + + Owned handle_; + hipDataType datatype_; + }; + + class MatmulDesc { + public: + static xla::StatusOr Create( + blas::ComputationType compute_type, blas::DataType scale_type, + blas::Transpose trans_a = blas::Transpose::kNoTranspose, + blas::Transpose trans_b = blas::Transpose::kNoTranspose, + Epilogue epilogue = Epilogue::kDefault, + PointerMode pointer_mode = PointerMode::kHost); + + hipblasComputeType_t compute_type() const { return compute_type_; } + hipDataType scale_type() const { return datatype_; } + bool has_bias_epilogue() const { return has_bias_epilogue_; } + //hipblasPointerMode_t pointer_mode() const { + // return HIPBLAS_POINTER_MODE_HOST; + //} + hipblasLtMatmulDesc_t get() const { return handle_.get(); } + + private: + MatmulDesc(hipblasLtMatmulDesc_t handle, hipblasComputeType_t compute_type, + hipDataType datatype, bool bias_epilogue) + : handle_(handle, hipblasLtMatmulDescDestroy), + compute_type_(compute_type), + datatype_(datatype), + has_bias_epilogue_(bias_epilogue) {} + + Owned handle_; + hipblasComputeType_t compute_type_; + hipDataType datatype_; + bool has_bias_epilogue_; + }; + + struct MatmulPlan : public gpu::BlasLt::MatmulPlan { + MatmulPlan(const BlasLt& blas_lt_ref, MatmulDesc&& op_desc, + MatrixLayout&& a_desc, MatrixLayout&& b_desc, + MatrixLayout&& c_desc, MatrixLayout&& d_desc, + xla::complex128 alpha, double beta, bool must_swap_operands) + : blas_lt_ref_(blas_lt_ref), + op_desc_(std::move(op_desc)), + a_desc_(std::move(a_desc)), + b_desc_(std::move(b_desc)), + c_desc_(std::move(c_desc)), + d_desc_(std::move(d_desc)), + alpha_(alpha), + beta_(beta), + must_swap_operands_(must_swap_operands) {} + + ~MatmulPlan() override = default; + + xla::StatusOr> GetAlgorithms( + size_t max_algorithm_count, size_t max_workspace_size) const override; + + xla::Status SetAlgorithm(const MatmulAlgorithm& algorithm) override; + + xla::Status ExecuteOnStream( + Stream* stream, DeviceMemoryBase a_buffer, DeviceMemoryBase b_buffer, + DeviceMemoryBase c_buffer, DeviceMemoryBase d_buffer, + DeviceMemoryBase bias_buffer, // may be null + DeviceMemoryBase aux_buffer, // may be null + DeviceMemoryBase a_scale_buffer, DeviceMemoryBase b_scale_buffer, + DeviceMemoryBase c_scale_buffer, DeviceMemoryBase d_scale_buffer, + DeviceMemoryBase d_amax_buffer, + absl::optional workspace, + absl::optional scratch_allocator, + blas::ProfileResult* profile_result) const override; + + protected: + xla::Status DoMatmul(Stream* stream, const void* alpha, DeviceMemoryBase a, + DeviceMemoryBase b, const void* beta, + DeviceMemoryBase c, DeviceMemoryBase d, + DeviceMemoryBase bias, DeviceMemoryBase aux, + DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, + DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, + DeviceMemoryBase d_amax, + absl::optional workspace, + absl::optional scratch_allocator, + blas::ProfileResult* profile_result) const override; + + private: + const BlasLt& blas_lt_ref_; + // TODO(cjfj): Add consistency checks for types, shapes, etc.? + MatmulDesc op_desc_; + MatrixLayout a_desc_; + MatrixLayout b_desc_; + MatrixLayout c_desc_; + MatrixLayout d_desc_; + xla::complex128 alpha_; + double beta_; + bool must_swap_operands_; + absl::optional< MatmulAlgorithm > algorithm_; // selected algorithm + }; // struct MatmulPlan + + struct GroupedMatmulPlan : public gpu::BlasLt::GroupedMatmulPlan { + + friend class BlasLt; + using DeviceMemoryArgs = DeviceMemoryBase; // OwningDeviceMemory + using GroupedGemmPtr = std::unique_ptr< hipblaslt_ext::GroupedGemm >; + + GroupedMatmulPlan(const BlasLt& blas_lt); + + xla::Status SetAlgorithm(const MatmulAlgorithm& algorithm, + ScratchAllocator *scratch_allocator) override; + + xla::Status ExecuteOnStream(Stream *stream, + const gpu::GroupedGemmConfig& cfg, + blas::ProfileResult* profile_result) override; + + xla::StatusOr> GetAlgorithms( + size_t max_algorithm_count, + size_t max_workspace_size) override; + + ~GroupedMatmulPlan() override; + + private: + const BlasLt& blas_lt_ref_; + GroupedGemmPtr grouped_gemm_; + hipblaslt_ext::UserArguments *host_args_ = nullptr; + DeviceMemoryArgs device_args_; + absl::optional< MatmulAlgorithm > algorithm_; // selected algorithm + + SE_DISALLOW_COPY_AND_ASSIGN(GroupedMatmulPlan); + }; + + explicit BlasLt(gpu::GpuExecutor* parent); + + xla::Status Init() override; + + xla::StatusOr GetMatmulPlan( + const gpu::GemmConfig& cfg) const override; + + xla::StatusOr GetGroupedMatmulPlan( + Stream *stream, + const gpu::GroupedGemmConfig& config) const override; + + ~BlasLt() override = default; + + private: + gpu::GpuExecutor* parent_; + mutable absl::Mutex mu_; + Owned blas_lt_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace rocm +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIP_BLAS_LT_H_ diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.cc b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.cc new file mode 100644 index 00000000000000..7d2b5853c466cf --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.cc @@ -0,0 +1,78 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h" +#include "tensorflow/compiler/xla/stream_executor/blas.h" +#include "tensorflow/compiler/xla/util.h" + +namespace stream_executor { +namespace rocm { + +xla::Status ToStatus(hipblasStatus_t status, const char* prefix) { + if (status != HIPBLAS_STATUS_SUCCESS) { + return xla::InternalError("%s: HipblasLt error %d", + prefix, static_cast(status)); + } + return xla::OkStatus(); +} + +hipDataType AsHipblasDataType(blas::DataType type) { + switch (type) { + case blas::DataType::kHalf: + return HIP_R_16F; + case blas::DataType::kBF16: + return HIP_R_16BF; + case blas::DataType::kFloat: + return HIP_R_32F; + case blas::DataType::kDouble: + return HIP_R_64F; + case blas::DataType::kInt8: + return HIP_R_8I; + case blas::DataType::kInt32: + return HIP_R_32I; + case blas::DataType::kComplexFloat: + return HIP_C_32F; + case blas::DataType::kComplexDouble: + return HIP_C_64F; + default: + LOG(FATAL) << "unknown data type"; + } +} + +hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type) { + switch(type) { + case blas::ComputationType::kF16AsF32: + return HIPBLAS_COMPUTE_32F_FAST_16F; + case blas::ComputationType::kBF16AsF32: + return HIPBLAS_COMPUTE_32F_FAST_16BF; + case blas::ComputationType::kTF32AsF32: + return HIPBLAS_COMPUTE_32F_FAST_TF32; + case blas::ComputationType::kF32: + return HIPBLAS_COMPUTE_32F; + default:; + } + LOG(FATAL) << "unsupported hipblaslt computation type"; +} + +hipblasOperation_t AsHipblasOperation(blas::Transpose trans) { + switch (trans) { + case blas::Transpose::kNoTranspose: + return HIPBLAS_OP_N; + case blas::Transpose::kTranspose: + return HIPBLAS_OP_T; + case blas::Transpose::kConjugateTranspose: + return HIPBLAS_OP_C; + } +} + +} // namespace rocm +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h new file mode 100644 index 00000000000000..447aa4b2777eaf --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIP_BLAS_UTILS_H_ +#define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIP_BLAS_UTILS_H_ + +#include + +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/stream_executor/blas.h" +#include "tensorflow/compiler/xla/stream_executor/rocm/hipblaslt_wrapper.h" + +#if 0 //TF_ROCM_VERSION < 60000 +#define hipDataType hipblasDatatype_t +#define HIP_R_16F HIPBLAS_R_16F +#define HIP_R_16BF HIPBLAS_R_16B +#define HIP_R_32F HIPBLAS_R_32F +#define HIP_R_64F HIPBLAS_R_64F +#define HIP_R_8I HIPBLAS_R_8I +#define HIP_R_32I HIPBLAS_R_32I +#define HIP_C_32F HIPBLAS_C_32F +#define HIP_C_64F HIPBLAS_C_64F + +#define hipblasComputeType_t hipblasLtComputeType_t +#define HIPBLAS_COMPUTE_32F HIPBLASLT_COMPUTE_F32 +#define HIPBLAS_COMPUTE_64F HIPBLASLT_COMPUTE_F64 +#define HIPBLAS_COMPUTE_32I HIPBLASLT_COMPUTE_I32 +#endif + +namespace stream_executor { +namespace rocm { + +#define SE_HIPBLAS_RETURN_IF_ERROR(expr) \ + TF_RETURN_IF_ERROR(::stream_executor::rocm::ToStatus(expr, #expr)) + +xla::Status ToStatus(hipblasStatus_t status, const char* prefix); +hipDataType AsHipblasDataType(blas::DataType type); +hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type); +hipblasOperation_t AsHipblasOperation(blas::Transpose trans); + +} // namespace rocm +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIP_BLAS_UTILS_H_ diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hipblaslt_wrapper.h b/tensorflow/compiler/xla/stream_executor/rocm/hipblaslt_wrapper.h new file mode 100644 index 00000000000000..65ce21312f2cda --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/rocm/hipblaslt_wrapper.h @@ -0,0 +1,102 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file wraps rocsolver API calls with dso loader so that we don't need to +// have explicit linking to librocsolver. All TF hipsarse API usage should route +// through this wrapper. + +#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ + +#define __HIP_DISABLE_CPP_FUNCTIONS__ + +#define LEGACY_HIPBLAS_DIRECT +#include "rocm/rocm_config.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" + +//#if TF_ROCM_VERSION >= 50500 +#include "rocm/include/hipblaslt/hipblaslt.h" +// #else +// #include "rocm/include/hipblaslt.h" +// #endif + +// NOTE: dynamic loader is disabled since hipblaslt-ext is linked statically! +#if 0 +namespace stream_executor { +namespace wrap { + +#ifdef PLATFORM_GOOGLE + +#define HIPBLASLT_API_WRAPPER(api_name) \ + template \ + auto api_name(Args... args) -> decltype(::api_name(args...)) { \ + return ::api_name(args...); \ + } + +#else + +#define TO_STR_(x) #x +#define TO_STR(x) TO_STR_(x) + +#define HIPBLASLT_API_WRAPPER(api_name) \ + template \ + auto api_name(Args... args) -> decltype(::api_name(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char* kName = TO_STR(api_name); \ + void* f; \ + auto s = port::Env::Default() -> GetSymbolFromLibrary( \ + internal::CachedDsoLoader::GetHipblasltDsoHandle() \ + .ValueOrDie(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in hipblaslt lib; dlerror: " << s.error_message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ + } + +#endif + +// clang-format off +#define FOREACH_HIPBLASLT_API(__macro) \ + __macro(hipblasLtCreate) \ + __macro(hipblasLtDestroy) \ + __macro(hipblasLtMatmulPreferenceCreate) \ + __macro(hipblasLtMatmulPreferenceSetAttribute) \ + __macro(hipblasLtMatmulPreferenceDestroy) \ + __macro(hipblasLtMatmulDescSetAttribute) \ + __macro(hipblasLtMatmulDescGetAttribute) \ + __macro(hipblasLtMatmulAlgoGetHeuristic) \ + __macro(hipblasLtMatrixLayoutCreate) \ + __macro(hipblasLtMatrixLayoutDestroy) \ + __macro(hipblasLtMatrixLayoutSetAttribute) \ + __macro(hipblasLtMatrixLayoutGetAttribute) \ + __macro(hipblasLtMatmulDescCreate) \ + __macro(hipblasLtMatmulDescDestroy) \ + __macro(hipblasLtMatmul) \ + __macro(hipblasStatusToString) +// clang-format on + +FOREACH_HIPBLASLT_API(HIPBLASLT_API_WRAPPER) + +#undef TO_STR_ +#undef TO_STR +#undef FOREACH_HIPBLASLT_API +#undef HIPBLASLT_API_WRAPPER + +} // namespace wrap +} // namespace stream_executor +#endif + +#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc b/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc index 600d5dea60c5a2..1228ee4c4d672e 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc @@ -105,12 +105,16 @@ bool ROCMBlas::Init() { LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret); return false; } - + if (!blas_lt_.Init().ok()) { + LOG(ERROR) << "Failed to initialize hipblasLt"; + return false; + } return true; } ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent) - : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {} + : parent_(CHECK_NOTNULL(parent)), blas_(nullptr), + blas_lt_(parent) {} ROCMBlas::~ROCMBlas() { if (blas_ != nullptr) { @@ -120,17 +124,13 @@ ROCMBlas::~ROCMBlas() { } bool ROCMBlas::SetStream(Stream *stream) { - CHECK(stream != nullptr); - CHECK(AsGpuStreamValue(stream) != nullptr); - CHECK(blas_ != nullptr); - gpu::ScopedActivateExecutorContext sac{parent_}; rocblas_status ret = - wrap::rocblas_set_stream(blas_, AsGpuStreamValue(stream)); + wrap::rocblas_set_stream(blas_, + stream != nullptr ? AsGpuStreamValue(stream) : 0); if (ret != rocblas_status_success) { LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret); return false; } - return true; } @@ -192,13 +192,13 @@ bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, Args... args) { absl::MutexLock lock{&mu_}; + gpu::ScopedActivateExecutorContext sac{parent_}; + CHECK(blas_ != nullptr); if (!SetStream(stream)) { return false; } - gpu::ScopedActivateExecutorContext sac{parent_}; - // set the atomics mode, leaving default to library bool allow_atomics = !OpDeterminismRequired(); rocblas_status ret; @@ -211,6 +211,7 @@ bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, } ret = rocblas_func(blas_, args...); + SetStream(nullptr); if (err_on_failure && ret != rocblas_status_success) { LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": " << ToString(ret); @@ -545,8 +546,12 @@ tsl::Status ROCMBlas::DoBlasGemmWithAlgorithm( blas::AlgorithmType algorithm, blas::ComputePrecision precision, blas::ProfileResult *output_profile_result, blas::CallContext context) { - // ROCM TODO: properly implement the interface - return tsl::errors::Internal("Not implemented on ROCm"); + + if (!(type_a == type_b && type_b == type_c)) { + return tsl::errors::Internal("Mixed-precision is NYI!"); + } + return DoBlasGemm(stream, transa, transb, m, n, k, type_a, + alpha, a, lda, b, ldb, beta, c, ldc, precision, context); } tsl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( @@ -559,8 +564,13 @@ tsl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( blas::AlgorithmType algorithm, blas::ComputePrecision precision, blas::ProfileResult *output_profile_result, blas::CallContext context) { - // ROCM TODO: properly implement the interface - return tsl::errors::Internal("Not implemented on ROCm"); + + if (!(type_a == type_b && type_b == type_c)) { + return tsl::errors::Internal("Mixed-precision is NYI!"); + } + return DoBlasGemmStridedBatched(stream, transa, transb, m, + n, k, type_a, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count, precision, context); } bool ROCMBlas::GetBlasGemmAlgorithms( diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.h b/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.h index 7435a7a4b617e1..579e154441cdd8 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.h @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h" #include "tensorflow/compiler/xla/stream_executor/temporary_device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h" + namespace stream_executor { class Stream; @@ -94,6 +96,10 @@ class ROCMBlas : public blas::BlasSupport { TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES + gpu::BlasLt *GetBlasLt() override { + return &blas_lt_; + } + private: // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream. // @@ -190,6 +196,7 @@ class ROCMBlas : public blas::BlasSupport { // rocBLAS library handle on the device. rocblas_handle blas_ ABSL_GUARDED_BY(mu_); + rocm::BlasLt blas_lt_; SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas); }; diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver.cc b/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver.cc index 28ffe2375aec1e..47dc454348905e 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver.cc @@ -397,23 +397,34 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { } /* static */ tsl::Status GpuDriver::LaunchKernel( - GpuContext* context, absl::string_view kernel_name, hipFunction_t function, - unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, - unsigned int block_dim_z, unsigned int shared_mem_bytes, - GpuStreamHandle stream, void** kernel_params, void** extra) { + GpuContext* context, hipFunction_t function, unsigned int grid_dim_x, + unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, + unsigned int block_dim_y, unsigned int block_dim_z, + unsigned int shared_mem_bytes, GpuStreamHandle stream, void** kernel_params, + void** extra) { ScopedActivateContext activation{context}; - VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x + VLOG(2) << "launching kernel: " << function << "; gdx: " << grid_dim_x << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z << " bdx: " << block_dim_x << " bdy: " << block_dim_y << " bdz: " << block_dim_z << " smem: " << shared_mem_bytes; - RETURN_IF_ROCM_ERROR(wrap::hipModuleLaunchKernel( - function, grid_dim_x, grid_dim_y, grid_dim_z, - block_dim_x, block_dim_y, block_dim_z, - shared_mem_bytes, stream, kernel_params, extra), - "Failed to launch ROCm kernel: ", kernel_name, - " with block dimensions: ", block_dim_x, "x", - block_dim_y, "x", block_dim_z); + + // for in-process kernels we use non-null kernel_params: + auto res = hipSuccess; + if (kernel_params != nullptr) { + res = wrap::hipLaunchKernel((const void*)function, + dim3(grid_dim_x, grid_dim_y, grid_dim_z), + dim3(block_dim_x, block_dim_y, block_dim_z), + kernel_params, shared_mem_bytes, stream); + } else { + res = wrap::hipModuleLaunchKernel( + function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, + block_dim_z, shared_mem_bytes, stream, nullptr, extra); + } + + if (res != hipSuccess) { + return tsl::errors::Internal( + absl::StrCat("Failed to launch ROCM kernel: ", ToString(res))); + } VLOG(2) << "successfully launched kernel"; return tsl::OkStatus(); } diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h b/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h index ad62ffdf4551f0..2d46899ae37e1a 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -97,6 +97,7 @@ namespace wrap { __macro(hipHostUnregister) \ __macro(hipInit) \ __macro(hipLaunchHostFunc) \ + __macro(hipLaunchKernel) \ __macro(hipMalloc) \ __macro(hipMemGetAddressRange) \ __macro(hipMemGetInfo) \ diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_gpu_executor.cc b/tensorflow/compiler/xla/stream_executor/rocm/rocm_gpu_executor.cc index bd2eb602aa1c7b..ddbdd8f4a58632 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_gpu_executor.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_gpu_executor.cc @@ -260,23 +260,39 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, &module)); } kernel_to_gpu_binary_[kernel] = hsaco; + } else if (spec.has_in_process_symbol()) { + kernelname = &spec.in_process_symbol().kernelname(); + void* symbol = spec.in_process_symbol().symbol(); + + VLOG(1) << "Resolve ROCM kernel " << *kernelname + << " from symbol pointer: " << symbol; + + *rocm_kernel->gpu_function_ptr() = static_cast(symbol); + rocm_kernel->SetInProcessSymbol(true); } else { return tsl::errors::Internal("No method of loading ROCM kernel provided"); } - VLOG(2) << "getting function " << *kernelname << " from module " << module; - if (!GpuDriver::GetModuleFunction(context_, module, kernelname->c_str(), + // If we resolved kernel from a symbol pointer, there is no need to load it + // from a module, as ROCm runtime did that automatically for us. + if (!spec.has_in_process_symbol()){ + VLOG(2) << "getting function " << *kernelname << " from module " << module; + if (!GpuDriver::GetModuleFunction(context_, module, kernelname->c_str(), rocm_kernel->gpu_function_ptr())) { - return tsl::errors::Internal("Failed getting module function"); + return tsl::errors::Internal("Failed getting module function"); + } } // We have to trust the kernel loader spec arity because there doesn't appear // to be a way to reflect on the number of expected arguments w/the ROCM API. rocm_kernel->set_arity(spec.arity()); - KernelMetadata kernel_metadata; - TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata)); - kernel->set_metadata(kernel_metadata); + // unable to get kernel metadata for in-process kernel + if (!spec.has_in_process_symbol()) { + KernelMetadata kernel_metadata; + TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata)); + kernel->set_metadata(kernel_metadata); + } kernel->set_name(*kernelname); return tsl::OkStatus(); } @@ -330,17 +346,25 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, VLOG(2) << "*(arg.address): " << reinterpret_cast( *static_cast(arg.address)); - kernargs.push_back( + // in-process kernel launches need a list of pointers to the params + kernargs.push_back(rocm_kernel->IsInProcessSymbol() ? + const_cast(arg.address) : reinterpret_cast(*static_cast(arg.address))); } + if(rocm_kernel->IsInProcessSymbol()) { + return GpuDriver::LaunchKernel( + GetGpuContext(stream), hipfunc, block_dims.x, block_dims.y, block_dims.z, + thread_dims.x, thread_dims.y, thread_dims.z, + args.number_of_shared_bytes(), hipstream, kernargs.data(), nullptr); + } size_t size = sizeof(void*) * kernargs.size(); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, kernargs.data(), HIP_LAUNCH_PARAM_BUFFER_SIZE, &size, HIP_LAUNCH_PARAM_END}; return GpuDriver::LaunchKernel( - GetGpuContext(stream), kernel.name(), hipfunc, block_dims.x, block_dims.y, - block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, + GetGpuContext(stream), hipfunc, block_dims.x, block_dims.y, block_dims.z, + thread_dims.x, thread_dims.y, thread_dims.z, args.number_of_shared_bytes(), hipstream, nullptr, (void**)&config); } diff --git a/tensorflow/compiler/xla/stream_executor/stream.cc b/tensorflow/compiler/xla/stream_executor/stream.cc index f9d3f37a73b63c..2080bff5b004dc 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.cc +++ b/tensorflow/compiler/xla/stream_executor/stream.cc @@ -33,6 +33,8 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/tsl/platform/stacktrace.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h" namespace stream_executor { @@ -1427,6 +1429,117 @@ Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, x, incx, beta, y, incy); } +template +tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64 k, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + DeviceMemory *c, int ldc, + blas::ComputePrecision precision, + blas::CallContext context) { + InputType alpha{1.0}; + InputType beta{0.0}; + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.Run(*this, transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, + /* allocator */nullptr)); //! NOTE: allocator is not available!! + return ::tsl::OkStatus(); + } + return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc, precision, context); +} + +#define INSTANTIATE_THEN_BLAS_GEMM(INPUT_TYPE) \ + template tsl::Status Stream::ThenBlasGemm( \ + blas::Transpose transa, blas::Transpose transb, \ + uint64_t m, uint64 n, uint64 k, \ + const DeviceMemory& a, int lda, \ + const DeviceMemory& b, int ldb, \ + DeviceMemory* c, int ldc, \ + blas::ComputePrecision precision, \ + blas::CallContext context); + +INSTANTIATE_THEN_BLAS_GEMM(float) +INSTANTIATE_THEN_BLAS_GEMM(double) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::half) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16) +INSTANTIATE_THEN_BLAS_GEMM(std::complex) +INSTANTIATE_THEN_BLAS_GEMM(std::complex) + +#undef INSTANTIATE_THEN_BLAS_GEMM + +template +tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64 k, ConstantType alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + ConstantType beta, DeviceMemory *c, + int ldc, blas::ComputePrecision precision, + blas::CallContext context) { + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.Run(*this, transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, + /* allocator */nullptr)); //! NOTE: allocator is not available!! + return ::tsl::OkStatus(); + } + static_assert( + detail::is_any_of, std::complex>(), + "Input can be half, bf16, float, double, std::complex or " + "std::complex"); + static_assert(!std::is_same_v || + detail::is_any_of(), + "If input is Eigen::half, constant has to be either " + "Eigen::half or float"); + static_assert(!std::is_same_v || + detail::is_any_of(), + "If input is Eigen::bfloat16, constant has to be either " + "Eigen::bfloat16 or float"); + static_assert( + detail::is_any_of(), + "If input is not Eigen::half, constant and input types have to match"); + blas::BlasSupport *blas = parent()->AsBlas(); + if (!blas) { + return tsl::errors::Internal( + "Attempting to perform BLAS operation using " + "StreamExecutor without BLAS support"); + } + + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + + return blas->DoBlasGemm(this, transa, transb, m, n, k, + blas::ToDataType::value, alpha_ptr, a, + lda, b, ldb, beta_ptr, c, ldc, precision, + context); +} + +#define INSTANTIATE_THEN_BLAS_GEMM(INPUT_TYPE, CONSTANT_TYPE) \ + template tsl::Status Stream::ThenBlasGemm( \ + blas::Transpose transa, blas::Transpose transb, \ + uint64_t m, uint64 n, uint64 k, CONSTANT_TYPE alpha, \ + const DeviceMemory& a, int lda, \ + const DeviceMemory& b, int ldb, \ + CONSTANT_TYPE beta, DeviceMemory* c, int ldc, \ + blas::ComputePrecision precision, \ + blas::CallContext context); + +INSTANTIATE_THEN_BLAS_GEMM(float, float) +INSTANTIATE_THEN_BLAS_GEMM(double, double) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::half, Eigen::half) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16, Eigen::bfloat16) +INSTANTIATE_THEN_BLAS_GEMM(std::complex, std::complex) +INSTANTIATE_THEN_BLAS_GEMM(std::complex, std::complex) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::half, float) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16, float) + +#undef INSTANTIATE_THEN_BLAS_GEMM + namespace { // Like ThenBlasImpl, except this expects the last argument of blas_func to be a // blas::ProfileResult*. This functor doesn't put the stream into an error @@ -1590,6 +1703,12 @@ Stream &Stream::ThenBlasGemmBatched( uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, blas::CallContext context) { + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); + return *this; + } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); @@ -1643,6 +1762,12 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, blas::CallContext context) { + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); + return *this; + } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); @@ -1675,6 +1800,12 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, blas::CallContext context) { + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); + return *this; + } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); @@ -1706,6 +1837,12 @@ Stream &Stream::ThenBlasGemmBatched( DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, blas::CallContext context) { + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); + return *this; + } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); @@ -1740,6 +1877,12 @@ Stream &Stream::ThenBlasGemmBatched( DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, blas::CallContext context) { + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); + return *this; + } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); diff --git a/tensorflow/compiler/xla/stream_executor/stream.h b/tensorflow/compiler/xla/stream_executor/stream.h index 53a1be6ec30d13..d3d30bc0257752 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.h +++ b/tensorflow/compiler/xla/stream_executor/stream.h @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/compiler/xla/stream_executor/temporary_memory_manager.h" + namespace stream_executor { namespace host { @@ -898,24 +899,7 @@ class Stream { const DeviceMemory &b, int ldb, DeviceMemory *c, int ldc, blas::ComputePrecision precision, - blas::CallContext context) { - InputType alpha{1.0}; - InputType beta{0.0}; - return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, precision, context); - } - - // TODO(parkers): Update all callers to pass kDefaultComputePrecision. - template - tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - DeviceMemory *c, int ldc, - blas::CallContext context) { - return ThenBlasGemm(transa, transb, m, n, k, a, lda, b, ldb, c, ldc, - blas::kDefaultComputePrecision,context); - } + blas::CallContext context); template tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, @@ -924,49 +908,7 @@ class Stream { const DeviceMemory &b, int ldb, ConstantType beta, DeviceMemory *c, int ldc, blas::ComputePrecision precision, - blas::CallContext context) { - static_assert( - detail::is_any_of, std::complex>(), - "Input can be half, bf16, float, double, std::complex or " - "std::complex"); - static_assert(!std::is_same_v || - detail::is_any_of(), - "If input is Eigen::half, constant has to be either " - "Eigen::half or float"); - static_assert( - detail::is_any_of(), - "If input is not Eigen::half, constant and input types have to match"); - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return tsl::errors::Internal( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - - return blas->DoBlasGemm(this, transa, transb, m, n, k, - blas::ToDataType::value, alpha_ptr, a, - lda, b, ldb, beta_ptr, c, ldc, precision, - context); - } - - // TODO(parkers): Update all callers to pass kDefaultComputePrecision. - template - tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, ConstantType alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - ConstantType beta, DeviceMemory *c, - int ldc, blas::CallContext context) { - return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, blas::kDefaultComputePrecision, context); - } + blas::CallContext context); template tsl::Status ThenBlasGemmWithAlgorithm( @@ -980,9 +922,9 @@ class Stream { OutputType alpha{1}; OutputType beta{0}; return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, computation_type, - algorithm, blas::kDefaultComputePrecision, - output_profile_result, context); + ldb, beta, c, ldc, computation_type, + algorithm, blas::kDefaultComputePrecision, + output_profile_result, context); } template diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h index ff785541ae2b61..b6e2bb1c9aa29a 100644 --- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h @@ -475,6 +475,11 @@ class StreamExecutor { absl::string_view kernel_name, absl::string_view ptx, absl::Span cubin_data); + // Same as above but for in-process kernels + template + tsl::StatusOr>> CreateTypedKernel( + absl::string_view kernel_name, void *symbol); + // Warning: use Stream::ThenLaunch instead, this method is not for general // consumption. However, this is the only way to launch a kernel for which // the type signature is only known at runtime; say, if an application @@ -834,6 +839,18 @@ StreamExecutor::CreateTypedKernel(absl::string_view kernel_name, return std::move(kernel_base); } +template +inline tsl::StatusOr>> +StreamExecutor::CreateTypedKernel(absl::string_view kernel_name, + void *symbol) { + auto kernel_base = absl::make_unique>(this); + MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters); + loader_spec.AddInProcessSymbol(symbol, kernel_name); + + TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get())); + return std::move(kernel_base); +} + template inline DeviceMemory StreamExecutor::AllocateArray(uint64_t element_count, int64_t memory_space) { diff --git a/tensorflow/compiler/xla/tests/matmul_test.cc b/tensorflow/compiler/xla/tests/matmul_test.cc index b28655a63ba7b3..ff297afb5548aa 100644 --- a/tensorflow/compiler/xla/tests/matmul_test.cc +++ b/tensorflow/compiler/xla/tests/matmul_test.cc @@ -30,11 +30,7 @@ class MatmulTestWithCublas : public HloTestBase, public: DebugOptions GetDebugOptionsForTest() override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); -#if TENSORFLOW_USE_ROCM - debug_options.set_xla_gpu_enable_cublaslt(false); -#else debug_options.set_xla_gpu_enable_cublaslt(use_cublas_lt_); -#endif return debug_options; } diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index 5bb0829a48c90a..e70121551c711e 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -747,10 +747,6 @@ tsl::StatusOr LhloDialectEmitter::EmitCustomCallOp( return EmitCublasLtMatmul(custom_call_instr); } - if (xla::gpu::IsCublasLtMatmulF8(*instr)) { - return EmitCublasLtMatmulF8(custom_call_instr); - } - if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) { return EmitDnnConvolution(custom_call_instr); } @@ -948,19 +944,19 @@ tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmul( TF_ASSIGN_OR_RETURN( bool has_vector_bias, - xla::gpu::cublas_lt::EpilogueAddsVectorBias(config.epilogue())); + xla::gpu::gpublas_lt::EpilogueAddsVectorBias(config.epilogue())); TF_ASSIGN_OR_RETURN( bool has_aux_output, - xla::gpu::cublas_lt::EpilogueHasAuxiliaryOutput(config.epilogue())); + xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(config.epilogue())); TF_RET_CHECK(custom_call->operand_count() == 2 + int{has_matrix_bias} + int{has_vector_bias}); xla::ShapeIndex output_index = - has_aux_output ? xla::ShapeIndex{0} : xla::ShapeIndex{}; + has_aux_output ? xla::ShapeIndex{0, 0} : xla::ShapeIndex{0}; - llvm::SmallVector operands; + llvm::SmallVector operands; TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); if (has_matrix_bias) { @@ -976,15 +972,18 @@ tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmul( } if (has_aux_output) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, {1})); + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, + xla::ShapeIndex{0, 1})); } + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, + xla::ShapeIndex{1})); auto op = CreateOpWithoutAttrs(custom_call, operands); SetMatmulAttributes(op, config, builder_); int32_t operand_sizes[] = { - 1, 1, 1, 1, has_vector_bias ? 1 : 0, has_aux_output ? 1 : 0}; + 1, 1, 1, 1, has_vector_bias ? 1 : 0, has_aux_output ? 1 : 0, 1}; op->setAttr(op.getOperandSegmentSizeAttr(), builder_.getDenseI32ArrayAttr(operand_sizes)); @@ -1013,7 +1012,7 @@ tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmulF8( TF_ASSIGN_OR_RETURN( bool has_vector_bias, - xla::gpu::cublas_lt::EpilogueAddsVectorBias(config.epilogue())); + xla::gpu::gpublas_lt::EpilogueAddsVectorBias(config.epilogue())); bool has_damax = custom_call->shape().IsTuple(); xla::ShapeIndex output_index = diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 0e0cf42af5bdb1..91c5e9398c44f2 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -517,7 +517,33 @@ message DebugOptions { bool xla_gpu_triton_gemm_any = 190; - // Next id: 210 + // Threshold to rewrite matmul to cuBLAS or Triton (minumum combined number of + // elements of both matrices in non-batch dimensions to be considered for a + // rewrite). + int64 xla_gpu_gemm_rewrite_size_threshold = 210; + + // File to write autotune results to. It will be a binary file unless the name + // ends with .txt or .textproto. Warning: The results are written at every + // compilation, possibly multiple times per process. This only works on CUDA. + string xla_gpu_dump_autotune_results_to = 211; + + // File to load autotune results from. It will be considered a binary file + // unless the name ends with .txt or .textproto. At most one loading will + // happen during the lifetime of one process, even if the first one is + // unsuccessful or different file paths are passed here. This only works on + // CUDA. + string xla_gpu_load_autotune_results_from = 212; + + // Relative precision for comparing different GEMM solutions + float xla_gpu_autotune_gemm_rtol = 213; + + // Higher values make it more likely that we'll catch an out-of-bounds read or + // write. Smaller values consume less memory during autotuning. Note that a + // fused cudnn conv has up to 6 total buffers (4 inputs, 1 output, and 1 + // scratch), so this can be multiplied by quite a lot. + int64 xla_gpu_redzone_padding_bytes = 214; + + // Next id: 215 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 77ec948af32c6e..794a80cbf06fdb 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -32,6 +32,7 @@ HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}' ROCR_RUNTIME_PATH = '%{rocr_runtime_path}' ROCR_RUNTIME_LIBRARY = '%{rocr_runtime_library}' VERBOSE = '%{crosstool_verbose}'=='1' +ROCM_AMDGPU_TARGETS = '%{rocm_amdgpu_targets}' def Log(s): print('gpus/crosstool: {0}'.format(s)) @@ -93,6 +94,27 @@ def GetHostCompilerOptions(argv): return opts +def GetHipccOptions(argv): + """Collect the -hipcc_options values from argv. + Args: + argv: A list of strings, possibly the argv passed to main(). + Returns: + The string that can be passed directly to hipcc. + """ + + parser = ArgumentParser() + parser.add_argument('--offload-arch', nargs='*', action='append') + # TODO find a better place for this + parser.add_argument('-gline-tables-only', action='store_true') + + args, _ = parser.parse_known_args(argv) + + hipcc_opts = ' -gline-tables-only ' if args.gline_tables_only else '' + if args.offload_arch: + hipcc_opts = hipcc_opts + ' '.join(['--offload-arch=' + a for a in sum(args.offload_arch, [])]) + + return hipcc_opts + def system(cmd): """Invokes cmd with os.system(). @@ -122,6 +144,7 @@ def InvokeHipcc(argv, log=False): """ host_compiler_options = GetHostCompilerOptions(argv) + hipcc_compiler_options = GetHipccOptions(argv) opt_option = GetOptionValue(argv, 'O') m_options = GetOptionValue(argv, 'm') m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']]) @@ -160,7 +183,7 @@ def InvokeHipcc(argv, log=False): srcs = ' '.join(src_files) out = ' -o ' + out_file[0] - hipccopts = ' ' + hipccopts = hipcc_compiler_options + ' ' # In hip-clang environment, we need to make sure that hip header is included # before some standard math header like is included in any source. # Otherwise, we get build error. @@ -214,8 +237,10 @@ def main(): parser.add_argument('-pass-exit-codes', action='store_true') args, leftover = parser.parse_known_args(sys.argv[1:]) - if VERBOSE: print('PWD=' + os.getcwd()) - if VERBOSE: print('HIPCC_ENV=' + HIPCC_ENV) + if VERBOSE: + print('PWD=' + os.getcwd()) + print('HIPCC_ENV=' + HIPCC_ENV) + print('ROCM_AMDGPU_TARGETS= ' + ROCM_AMDGPU_TARGETS) if args.x and args.x[0] == 'rocm': # compilation for GPU objects diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl index 716c63697c7a0c..2b4595bb222885 100644 --- a/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/gpus/rocm/build_defs.bzl.tpl @@ -48,6 +48,14 @@ def if_rocm_is_configured(x): return select({"//conditions:default": x}) return select({"//conditions:default": []}) +def rocm_hipblaslt(): + return %{rocm_is_configured} and %{rocm_hipblaslt} + +def if_rocm_hipblaslt(x): + if %{rocm_is_configured} and (%{rocm_hipblaslt} == "True"): + return select({"//conditions:default": x}) + return select({"//conditions:default": []}) + def rocm_library(copts = [], **kwargs): """Wrapper over cc_library which adds default ROCm options.""" native.cc_library(copts = rocm_default_copts() + copts, **kwargs) diff --git a/third_party/gpus/rocm/rocm_config.h.tpl b/third_party/gpus/rocm/rocm_config.h.tpl index ec26b00a5b5127..20506f64b2b9c6 100644 --- a/third_party/gpus/rocm/rocm_config.h.tpl +++ b/third_party/gpus/rocm/rocm_config.h.tpl @@ -21,5 +21,6 @@ limitations under the License. #define TF_ROCM_VERSION %{rocm_version_number} #define TF_MIOPEN_VERSION %{miopen_version_number} #define TF_HIPRUNTIME_VERSION %{hipruntime_version_number} +#define TF_HIPBLASLT %{hipblaslt_flag} #endif // ROCM_ROCM_CONFIG_H_ diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index e74573640295d8..a568b1c5f517e8 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -359,6 +359,10 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ if int(rocm_config.rocm_version_number) >= 40500: libs_paths.append(("hipsolver", _rocm_lib_paths(repository_ctx, "hipsolver", rocm_config.rocm_toolkit_path))) libs_paths.append(("hipblas", _rocm_lib_paths(repository_ctx, "hipblas", rocm_config.rocm_toolkit_path))) + + # hipblaslt may be absent even in versions of ROCm where it exists + # (it is not installed by default in some containers). Autodetect. + libs_paths.append(("hipblaslt", _rocm_lib_paths(repository_ctx, "hipblaslt", rocm_config.rocm_toolkit_path))) return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin) def _exec_find_rocm_config(repository_ctx, script_path): @@ -469,6 +473,7 @@ def _create_dummy_repository(repository_ctx): "%{rocm_extra_copts}": "[]", "%{rocm_gpu_architectures}": "[]", "%{rocm_version_number}": "0", + "%{rocm_hipblaslt}": "False", }, ) _tpl( @@ -487,6 +492,7 @@ def _create_dummy_repository(repository_ctx): "%{roctracer_lib}": _lib_name("roctracer64"), "%{rocsolver_lib}": _lib_name("rocsolver"), "%{hipsolver_lib}": _lib_name("hipsolver"), + "%{hipblaslt_lib}": _lib_name("hipblaslt"), "%{copy_rules}": "", "%{rocm_headers}": "", }, @@ -503,6 +509,7 @@ def _create_dummy_repository(repository_ctx): "rocm:rocm_config.h", { "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH, + "%{hipblaslt_flag}": "0", }, "rocm/rocm/rocm_config.h", ) @@ -642,6 +649,7 @@ def _create_local_rocm_repository(repository_ctx): ], )) + have_hipblaslt = "1" if rocm_libs["hipblaslt"] != None else "0" # Set up BUILD file for rocm/ repository_ctx.template( @@ -655,6 +663,7 @@ def _create_local_rocm_repository(repository_ctx): ), "%{rocm_gpu_architectures}": str(rocm_config.amdgpu_targets), "%{rocm_version_number}": str(rocm_version_number), + "%{rocm_hipblaslt}": "True" if rocm_libs["hipblaslt"] != None else "False", }, ) @@ -674,6 +683,9 @@ def _create_local_rocm_repository(repository_ctx): hiprand_include + rocrand_include), } + if rocm_libs["hipblaslt"] != None: + repository_dict["%{hipblaslt_lib}"] = rocm_libs["hipblaslt"].file_name + if rocm_version_number >= 40500: repository_dict["%{hipsolver_lib}"] = rocm_libs["hipsolver"].file_name repository_dict["%{hipblas_lib}"] = rocm_libs["hipblas"].file_name @@ -746,6 +758,9 @@ def _create_local_rocm_repository(repository_ctx): "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), + "%{rocm_amdgpu_targets}": ",".join( + ["\"%s\"" % c for c in rocm_config.amdgpu_targets], + ), }, ) @@ -762,6 +777,7 @@ def _create_local_rocm_repository(repository_ctx): "%{rocm_version_number}": rocm_config.rocm_version_number, "%{miopen_version_number}": rocm_config.miopen_version_number, "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, + "%{hipblaslt_flag}": have_hipblaslt, }, )