From 25d7a4fa0818a9b105fd4b8a31a536cd4917e59b Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Wed, 21 Aug 2024 17:30:16 -0700
Subject: [PATCH] [CUDA] Update benchmark_mha.py to capture debug info to identify sdpa kernel (#21804)

Use debug info to identify the sdpa kernel that is actually used, and show it in the output of benchmark_mha.py. This updated benchmark script was used to get the benchmark results in https://github.com/microsoft/onnxruntime/pull/21629.

(1) Change the debug info output format to a single line like SdpaKernel=*.
(2) Add a step that captures stdout from the onnxruntime session and uses a regular expression to parse SdpaKernel=* from the captured text (a standalone sketch of this capture-and-parse technique follows the patch).

Other minor changes:
(1) Set different default repeats for the benchmark: 100 for CPU and 10000 for CUDA.
(2) Fix PrintTensorByDims in the console dumper: if the dumper is not enabled, do not dump the tensor.
(3) Update some comments.

### Motivation and Context
Sometimes a fallback kernel is used for the requested sdpa_kernel. That can confuse users unless the benchmark reports the exact kernel that was used.
---
 .../contrib_ops/cpu/utils/console_dumper.h | 6 +-
 .../cuda/bert/attention_kernel_options.cc | 54 +++------
 .../bert/cudnn_fmha/cudnn_flash_attention.cu | 2 +-
 .../cuda/bert/multihead_attention.cc | 1 -
 .../contrib_ops/cuda/bert/packed_attention.cc | 4 +-
 .../cuda/bert/packed_multihead_attention.cc | 2 +-
 .../test/python/transformers/benchmark_mha.py | 106 ++++++++++++++++--
 7 files changed, 121 insertions(+), 54 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h
index 12cbc5049a02a..9ebc44f4411eb 100644
--- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h
+++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h
@@ -53,7 +53,11 @@ void PrintTensorByDims(const TConsoleDumper* dumper,
                        const char* name,
                        const T* tensor,
                        gsl::span& dims) {
-  if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) {
+  if (!dumper->IsEnabled()) {
+    return;
+  }
+
+  if ((tensor == nullptr || dims.size() == 0)) {
     std::cout << std::string(name) << " is None" << std::endl;
     return;
   }
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
index b2e80cb5035cb..7d21451df5b86 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -128,45 +128,23 @@ void AttentionKernelDebugInfo::Print(const char* operator_name,
     sstream << " DataType=fp32";
   }

+  sstream << " SdpaKernel=";
   if (use_flash_attention.has_value() && use_flash_attention.value()) {
-    sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value());
-  }
-
-  if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
-    sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value());
-  }
-
-  if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
-    sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value());
-  }
-
-  if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
-    sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value());
-  }
-
-  if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
-    sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value());
-  }
-
-  if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
-    sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value());
-  }
-
-  if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
-    sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value());
-  }
-
-  bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) ||
-                   (use_efficient_attention.has_value() && use_efficient_attention.value()) ||
-                   (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) ||
-                   (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) ||
-                   (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) ||
-                   (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) ||
-                   (use_trt_causal_attention.has_value() && use_trt_causal_attention.value());
-
-  // Fall back to unfused when no fused kernel is enabled.
-  if (!use_fused) {
-    sstream << " MATH=1";
+    sstream << "FLASH_ATTENTION";
+  } else if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
+    sstream << "EFFICIENT_ATTENTION";
+  } else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
+    sstream << "TRT_FUSED_ATTENTION";
+  } else if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
+    sstream << "CUDNN_FLASH_ATTENTION";
+  } else if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
+    sstream << "TRT_FLASH_ATTENTION";
+  } else if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
+    sstream << "TRT_CROSS_ATTENTION";
+  } else if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
+    sstream << "TRT_CAUSAL_ATTENTION";
+  } else {
+    sstream << "MATH";
   }

   // Output text in Cyan color to make it easier to spot.
diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu
index 426b105dff8db..f334b72e70a34 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu
@@ -314,7 +314,7 @@ struct BytesHash {
 };

 // Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe.
-// TODO(tianleiwu): since we the key includes sequence lengths, we may want to limit the cache size.
+// TODO(tianleiwu): since the key includes sequence lengths, we may want to limit the cache size.
 thread_local std::unordered_map, BytesHash > mha_graph_cache;
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index 2ad8bc4015a47..0960a9efe7699 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -233,7 +233,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
   bool use_fused_runner = kernel_type == AttentionKernelType::AttentionKernel_Default &&
                           !disable_fused_self_attention_ &&
-                          fused_cross_attention_kernel == nullptr &&
                           nullptr == attention_bias &&
                           (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
                           nullptr == past_key && nullptr == present_key &&
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
index 0e5300f32da3c..f486d08244547 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -111,7 +111,7 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape,
   // Abbreviation and Meanings:
   //   T: token_count
   //   B: batch_size
-  //   S: sequence_length (input sequence length of query)
+  //   S: sequence_length
   //   N: num_heads
   //   H: head size for Q and K, aka q_head_size or v_head_size or qk_head_size
   //   H_v: v_head_size
@@ -125,7 +125,7 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape,
   //   bias (Q/K/V)     : (D + D + D_v)
   //   token_offset     : (B, S)
   //   cu_seq_len_shape : (B + 1)
-  //   attention_bias   : (B, N, S, S), (1, N, S, S) or NULL
+  //   attention_bias   : (B or 1, N or 1, S, S) or NULL
   const auto& input_dims = input_shape.GetDims();
   if (input_dims.size() != 2) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
index 72a4c776d4fce..b0c3a28df2336 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
@@ -68,7 +68,7 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape,
   //   Input 'value': None
   //   Input 'token_offset': (batch_size, sequence_length)
   //   Input 'cumulative_sequence_length': (batch_size + 1)
-  //   Input 'attention_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None
+  //   Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None
   //   Output 'output': (token_count, v_hidden_size)
   const auto& query_dims = query_shape.GetDims();
diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py
index 4cc5ce4201ea1..2a3541db4c9b5 100644
--- a/onnxruntime/test/python/transformers/benchmark_mha.py
+++ b/onnxruntime/test/python/transformers/benchmark_mha.py
@@ -18,7 +18,10 @@
 import math
 import os
 import platform
+import re
 import statistics
+import sys
+import threading
 import time
 from contextlib import nullcontext
 from datetime import datetime
@@ -771,6 +774,72 @@ def get_compute_capability():
     return sm


+class CaptureStdout:
+    def __init__(self):
+        self.fd = sys.stdout.fileno()
+        self.chunk_size = 1024
+        self.output = b""
+
+    def _capture(self):
+        chunks = []
+        while chunk := os.read(self._pipe_reader, self.chunk_size):
+            chunks.append(chunk)
+        self.output = b"".join(chunks)
+
+    def __enter__(self):
+        self._duped_fd = os.dup(self.fd)
+        self._pipe_reader, pipe_writer = os.pipe()
+        os.dup2(pipe_writer, self.fd)
+        os.close(pipe_writer)
+        self._capture_thread = threading.Thread(target=self._capture)
+        self._capture_thread.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        os.close(self.fd)
+        self._capture_thread.join()
+        os.close(self._pipe_reader)
+        os.dup2(self._duped_fd, self.fd)
+        os.close(self._duped_fd)
+
+
+def sdpa_kernel_from_debug_info(
+    config: MultiHeadAttentionConfig, attention_kernel: SdpaKernel, sess_options: SessionOptions
+):
+    os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1"
+    captured_text = None
+    try:
+        with CaptureStdout() as captured:
+            session = create_session(config, sess_options, attention_kernel=attention_kernel)
+            input_dict = config.random_inputs()
+            session.infer(input_dict)
+        captured_text = captured.output.decode()
+    except Exception as e:
+        print(f"Failed to run {attention_kernel=} for {config=}. Exception: {e}")
+    finally:
+        os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0"
+
+    if captured_text is not None:
+        m = re.search("SdpaKernel=(?P<kernel>[A-Z_]+)", captured_text)
+        if m is not None:
+            name = m.group("kernel")
+            kernel_names = {
+                "FLASH_ATTENTION": "ort:flash",
+                "EFFICIENT_ATTENTION": "ort:efficient",
+                "CUDNN_FLASH_ATTENTION": "ort:cudnn",
+                "MATH": "ort:math",
+                "TRT_FUSED_ATTENTION": "ort:trt_fmha",
+                "TRT_FLASH_ATTENTION": "ort:trt_flash",
+                "TRT_CROSS_ATTENTION": "ort:trt_cross",
+                "TRT_CAUSAL_ATTENTION": "ort:trt_causal",
+            }
+            return kernel_names[name]
+        else:
+            print("Failed to get sdpa kernel from debug info:", captured_text)
+
+    return None
+
+
 def run_tflops_test(
     csv_writer: csv.DictWriter,
     args: argparse.Namespace,
@@ -809,7 +878,9 @@ def run_tflops_test(
         backends = [SdpaKernel.DEFAULT]

     configs = get_test_configs(args)
-    print("\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tkernel")
+    print(
+        "\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tsdpa_kernel\trequest_kernel"
+    )

     for input_format in formats:
         for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs:
@@ -836,14 +907,13 @@ def run_tflops_test(
             for attention_kernel in backends:
                 sess_options = SessionOptions()
                 sess_options.intra_op_num_threads = intra_op_num_threads
-                session = create_session(config, sess_options, attention_kernel=attention_kernel)

                 if use_gpu:
-                    kernel = get_gpu_kernel_name(attention_kernel)
+                    request_kernel = get_gpu_kernel_name(attention_kernel)
                 else:
-                    kernel = get_cpu_kernel_name(config)
+                    request_kernel = get_cpu_kernel_name(config)

-                if "math" in kernel:
+                if "math" in request_kernel:
                     # Skip large sequence length for Unfused kernel to avoid OOM.
                     if not enable_unfused:
                         if config.verbose:
@@ -856,13 +926,23 @@ def run_tflops_test(
                             print(f"skip input_format for {vars(config)}")
                         continue

+                if use_gpu:
+                    actual_kernel = sdpa_kernel_from_debug_info(config, attention_kernel, sess_options)
+                    if actual_kernel is None:
+                        print(f"Warning: skip {config} since kernel from debug info is None")
+                        continue
+                else:
+                    # CPU has no debug info for now.
+                    actual_kernel = request_kernel
+
+                session = create_session(config, sess_options, attention_kernel=attention_kernel)
                 input_dict = config.random_inputs()

                 # warm up session
                 try:
                     _ = measure_latency(session, input_dict)
                 except Exception as e:
-                    print(f"Failed to run {kernel=} for {config=}. Exception: {e}")
Exception: {e}") continue latency_list = [] @@ -898,7 +978,8 @@ def run_tflops_test( "intra_op_num_threads": intra_op_num_threads, "average_latency": average_latency, "tflops": speed, - "kernel": kernel, + "request_kernel": request_kernel, + "kernel": actual_kernel, } csv_writer.writerow(row) @@ -906,7 +987,7 @@ def run_tflops_test( print( f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t" f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t" - f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" + f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{actual_kernel}\t{request_kernel}" ) @@ -979,7 +1060,7 @@ def run_torch_test( print( f"{input_format}\t{causal}\t{False}\t{batch_size}\t" f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t" - f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}" + f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}\t{backend_name}" ) row = { "use_gpu": use_gpu, @@ -997,6 +1078,7 @@ def run_torch_test( "intra_op_num_threads": torch.get_num_threads(), "average_latency": torch_latency, "tflops": speed, + "request_kernel": backend_name, "kernel": backend_name, } csv_writer.writerow(row) @@ -1030,6 +1112,7 @@ def run_tflops_tests(args): "intra_op_num_threads", "average_latency", "tflops", + "request_kernel", "kernel", ] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) @@ -1224,7 +1307,7 @@ def _parse_arguments(): "--repeats", required=False, type=int, - default=100, + default=0, help="number of repeats for performance test", ) @@ -1269,6 +1352,9 @@ def _parse_arguments(): args = _parse_arguments() print(f"arguments:{args}") + if args.repeats == 0: + args.repeats = 10000 if args.use_gpu else 100 + if args.use_gpu: assert args.torch or not args.causal, "no causal cuda kernel in MHA op" assert torch.cuda.is_available()