[CUDA] Update benchmark_mha.py to capture debug info to identify sdpa kernel (#21804)

Use debug info to identify the sdpa kernel actually used, and show it in the
output of benchmark_mha.py. This updated benchmark script was used to
get the benchmark results in #21629.

(1) Change the debug info output format to print SdpaKernel=*.
(2) Add a step to capture stdout from the onnxruntime session, and use a
regular expression to parse SdpaKernel=* from the captured text (see the sketch below).
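
Below is a minimal sketch of the parse step, assuming the debug line printed by the CUDA provider contains SdpaKernel=FLASH_ATTENTION; the surrounding fields in the sample string are illustrative, and the real script additionally redirects the stdout file descriptor through a pipe (see the CaptureStdout class in the diff) to obtain this text:

```python
import re

# Simulated text captured from the onnxruntime session's stdout when
# ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1 is set. Fields other than
# SdpaKernel are illustrative placeholders.
captured_text = "Operator=MultiHeadAttention DataType=fp16 SdpaKernel=FLASH_ATTENTION"

# Map the debug-info name to the kernel label used in the benchmark output
# (a subset of the mapping in benchmark_mha.py).
kernel_names = {
    "FLASH_ATTENTION": "ort:flash",
    "EFFICIENT_ATTENTION": "ort:efficient",
    "CUDNN_FLASH_ATTENTION": "ort:cudnn",
    "MATH": "ort:math",
}

m = re.search(r"SdpaKernel=(?P<kernel>[A-Z_]+)", captured_text)
if m is not None:
    print(kernel_names.get(m.group("kernel"), m.group("kernel")))  # prints: ort:flash
```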

Other minor changes:
(1) Set different default repeats for the benchmark: 100 for CPU and
10000 for CUDA (see the sketch after this list).
(2) Fix PrintTensorByDims used in the console dumper: if the dumper is not
enabled, do not dump the tensor.
(3) Update some comments.
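
A minimal sketch of how the device-specific default for --repeats can be resolved after argument parsing (mirroring the change near the bottom of benchmark_mha.py; the argparse setup here is a trimmed illustration, not the script's full parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", action="store_true")
parser.add_argument("--repeats", type=int, default=0, help="0 means use a device-specific default")
args = parser.parse_args([])  # empty argv for illustration

# Resolve the sentinel default: CUDA kernels are much faster per call,
# so many more repeats are needed for stable timing.
if args.repeats == 0:
    args.repeats = 10000 if args.use_gpu else 100

print(args.repeats)  # prints: 100 (CPU default)
```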

### Motivation and Context

Sometimes a fallback is used instead of the requested sdpa_kernel. This could
confuse users unless the benchmark reports the exact kernel that was used.
tianleiwu authored Aug 22, 2024
1 parent 44a3923 commit 25d7a4f
Showing 7 changed files with 121 additions and 54 deletions.
6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/cpu/utils/console_dumper.h
@@ -53,7 +53,11 @@ void PrintTensorByDims(const TConsoleDumper* dumper,
const char* name,
const T* tensor,
gsl::span<const int64_t>& dims) {
if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) {
if (!dumper->IsEnabled()) {
return;
}

if ((tensor == nullptr || dims.size() == 0)) {
std::cout << std::string(name) << " is None" << std::endl;
return;
}
54 changes: 16 additions & 38 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -128,45 +128,23 @@ void AttentionKernelDebugInfo::Print(const char* operator_name,
sstream << " DataType=fp32";
}

sstream << " SdpaKernel=";
if (use_flash_attention.has_value() && use_flash_attention.value()) {
sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value());
}

if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value());
}

if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value());
}

if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value());
}

if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value());
}

if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value());
}

if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value());
}

bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) ||
(use_efficient_attention.has_value() && use_efficient_attention.value()) ||
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) ||
(use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) ||
(use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) ||
(use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) ||
(use_trt_causal_attention.has_value() && use_trt_causal_attention.value());

// Fall back to unfused when no fused kernel is enabled.
if (!use_fused) {
sstream << " MATH=1";
sstream << "FLASH_ATTENTION";
} else if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << "EFFICIENT_ATTENTION";
} else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
sstream << "TRT_FUSED_ATTENTION";
} else if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
sstream << "CUDNN_FLASH_ATTENTION";
} else if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
sstream << "TRT_FLASH_ATTENTION";
} else if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
sstream << "TRT_CROSS_ATTENTION";
} else if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
sstream << "TRT_CAUSAL_ATTENTION";
} else {
sstream << "MATH";
}

// Output text in Cyan color to make it easier to spot.
@@ -314,7 +314,7 @@ struct BytesHash {
};

// Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe.
// TODO(tianleiwu): since we the key includes sequence lengths, we may want to limit the cache size.
// TODO(tianleiwu): since the key includes sequence lengths, we may want to limit the cache size.
thread_local
std::unordered_map<GraphParams, std::shared_ptr<fe::graph::Graph>, BytesHash<GraphParams> > mha_graph_cache;

1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -233,7 +233,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
bool use_fused_runner =
kernel_type == AttentionKernelType::AttentionKernel_Default &&
!disable_fused_self_attention_ &&
fused_cross_attention_kernel == nullptr &&
nullptr == attention_bias &&
(parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
nullptr == past_key && nullptr == present_key &&
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -111,7 +111,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
// Abbreviation and Meanings:
// T: token_count
// B: batch_size
// S: sequence_length (input sequence length of query)
// S: sequence_length
// N: num_heads
// H: head size for Q and K, aka q_head_size or v_head_size or qk_head_size
// H_v: v_head_size
@@ -125,7 +125,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
// bias (Q/K/V) : (D + D + D_v)
// token_offset : (B, S)
// cu_seq_len_shape : (B + 1)
// attention_bias : (B, N, S, S), (1, N, S, S) or NULL
// attention_bias : (B or 1, N or 1, S, S) or NULL
const auto& input_dims = input_shape.GetDims();
if (input_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -68,7 +68,7 @@ Status PackedMultiHeadAttention<T>::CheckInputs(const TensorShape& query_shape,
// Input 'value': None
// Input 'token_offset': (batch_size, sequence_length)
// Input 'cumulative_sequence_length': (batch_size + 1)
// Input 'attention_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None
// Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None
// Output 'output': (token_count, v_hidden_size)

const auto& query_dims = query_shape.GetDims();
106 changes: 96 additions & 10 deletions onnxruntime/test/python/transformers/benchmark_mha.py
@@ -18,7 +18,10 @@
import math
import os
import platform
import re
import statistics
import sys
import threading
import time
from contextlib import nullcontext
from datetime import datetime
@@ -771,6 +774,72 @@ def get_compute_capability():
return sm


class CaptureStdout:
def __init__(self):
self.fd = sys.stdout.fileno()
self.chunk_size = 1024
self.output = b""

def _capture(self):
chunks = []
while chunk := os.read(self._pipe_reader, self.chunk_size):
chunks.append(chunk)
self.output = b"".join(chunks)

def __enter__(self):
self._duped_fd = os.dup(self.fd)
self._pipe_reader, pipe_writer = os.pipe()
os.dup2(pipe_writer, self.fd)
os.close(pipe_writer)
self._capture_thread = threading.Thread(target=self._capture)
self._capture_thread.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
os.close(self.fd)
self._capture_thread.join()
os.close(self._pipe_reader)
os.dup2(self._duped_fd, self.fd)
os.close(self._duped_fd)


def sdpa_kernel_from_debug_info(
config: MultiHeadAttentionConfig, attention_kernel: SdpaKernel, sess_options: SessionOptions
):
os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1"
captured_text = None
try:
with CaptureStdout() as captured:
session = create_session(config, sess_options, attention_kernel=attention_kernel)
input_dict = config.random_inputs()
session.infer(input_dict)
captured_text = captured.output.decode()
except Exception as e:
print(f"Failed to run {attention_kernel=} for {config=}. Exception: {e}")
finally:
os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0"

if captured_text is not None:
m = re.search("SdpaKernel=(?P<kernel>[A-Z_]+)", captured_text)
if m is not None:
name = m.group("kernel")
kernel_names = {
"FLASH_ATTENTION": "ort:flash",
"EFFICIENT_ATTENTION": "ort:efficient",
"CUDNN_FLASH_ATTENTION": "ort:cudnn",
"MATH": "ort:math",
"TRT_FUSED_ATTENTION": "ort:trt_fmha",
"TRT_FLASH_ATTENTION": "ort:trt_flash",
"TRT_CROSS_ATTENTION": "ort:trt_cross",
"TRT_CAUSAL_ATTENTION": "ort:trt_causal",
}
return kernel_names[name]
else:
print("Failed to get sdpa kernel from debug info:", captured_text)

return None


def run_tflops_test(
csv_writer: csv.DictWriter,
args: argparse.Namespace,
@@ -809,7 +878,9 @@ def run_tflops_test(
backends = [SdpaKernel.DEFAULT]

configs = get_test_configs(args)
print("\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tkernel")
print(
"\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tsdpa_kernel\trequest_kernel"
)

for input_format in formats:
for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs:
@@ -836,14 +907,13 @@ def run_tflops_test(
for attention_kernel in backends:
sess_options = SessionOptions()
sess_options.intra_op_num_threads = intra_op_num_threads
session = create_session(config, sess_options, attention_kernel=attention_kernel)

if use_gpu:
kernel = get_gpu_kernel_name(attention_kernel)
request_kernel = get_gpu_kernel_name(attention_kernel)
else:
kernel = get_cpu_kernel_name(config)
request_kernel = get_cpu_kernel_name(config)

if "math" in kernel:
if "math" in request_kernel:
# Skip large sequence length for Unfused kernel to avoid OOM.
if not enable_unfused:
if config.verbose:
@@ -856,13 +926,23 @@ def run_tflops_test(
print(f"skip input_format for {vars(config)}")
continue

if use_gpu:
actual_kernel = sdpa_kernel_from_debug_info(config, attention_kernel, sess_options)
if actual_kernel is None:
print(f"Warning: skip {config} since kernel from debug info is None")
continue
else:
# CPU has no debug info for now.
actual_kernel = request_kernel

session = create_session(config, sess_options, attention_kernel=attention_kernel)
input_dict = config.random_inputs()

# warm up session
try:
_ = measure_latency(session, input_dict)
except Exception as e:
print(f"Failed to run {kernel=} for {config=}. Exception: {e}")
print(f"Failed to run {request_kernel=} for {config=}. Exception: {e}")
continue

latency_list = []
@@ -898,15 +978,16 @@ def run_tflops_test(
"intra_op_num_threads": intra_op_num_threads,
"average_latency": average_latency,
"tflops": speed,
"kernel": kernel,
"request_kernel": request_kernel,
"kernel": actual_kernel,
}
csv_writer.writerow(row)

speed = f"{speed:.2f}" if speed is not None else "NA"
print(
f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t"
f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}"
f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{actual_kernel}\t{request_kernel}"
)


@@ -979,7 +1060,7 @@ def run_torch_test(
print(
f"{input_format}\t{causal}\t{False}\t{batch_size}\t"
f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}"
f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}\t{backend_name}"
)
row = {
"use_gpu": use_gpu,
@@ -997,6 +1078,7 @@ def run_torch_test(
"intra_op_num_threads": torch.get_num_threads(),
"average_latency": torch_latency,
"tflops": speed,
"request_kernel": backend_name,
"kernel": backend_name,
}
csv_writer.writerow(row)
@@ -1030,6 +1112,7 @@ def run_tflops_tests(args):
"intra_op_num_threads",
"average_latency",
"tflops",
"request_kernel",
"kernel",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
@@ -1224,7 +1307,7 @@ def _parse_arguments():
"--repeats",
required=False,
type=int,
default=100,
default=0,
help="number of repeats for performance test",
)

@@ -1269,6 +1352,9 @@ def _parse_arguments():
args = _parse_arguments()
print(f"arguments:{args}")

if args.repeats == 0:
args.repeats = 10000 if args.use_gpu else 100

if args.use_gpu:
assert args.torch or not args.causal, "no causal cuda kernel in MHA op"
assert torch.cuda.is_available()
