From 6209dcf5d89a680f5fa5d28a6df54a9e02339ec9 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 23 Jul 2024 00:25:35 +0000 Subject: [PATCH 1/5] update benchmark_mha to compare with PyTorch SDPA --- .../test/python/transformers/benchmark_mha.py | 685 ++++++++++++++---- .../test/python/transformers/benchmark_mha.sh | 40 +- .../test/python/transformers/test_mha.py | 46 +- 3 files changed, 570 insertions(+), 201 deletions(-) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 111c417479d20..4030b9369dcf0 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -4,21 +4,30 @@ # -------------------------------------------------------------------------- """ -Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux: +Benchmark performance of MultiHeadAttention with ORT or PyTorch. For example, run the the following in Linux: sh benchmark_mha.sh """ +import argparse +import csv import math import os import platform import statistics import time -from typing import List, Optional +from contextlib import nullcontext +from datetime import datetime +from enum import IntEnum +from typing import Callable, Dict, List, Optional, Tuple import torch +import torch.utils.benchmark as benchmark from onnx import TensorProto, helper +from packaging.version import Version +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.nn.functional import scaled_dot_product_attention -from onnxruntime import InferenceSession, get_available_providers +from onnxruntime import InferenceSession, SessionOptions, get_available_providers from onnxruntime.transformers.io_binding_helper import CudaSession @@ -43,6 +52,20 @@ def get_name_list() -> List[str]: return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"] +class SdpaKernel(IntEnum): + """Bit flags for sdpa_kernel CUDA provider option""" + + DEFAULT = 0 + FLASH_ATTENTION = 1 + EFFICIENT_ATTENTION = 2 + TRT_FUSED_ATTENTION = 4 + CUDNN_FLASH_ATTENTION = 8 + MATH = 16 + TRT_FLASH_ATTENTION = 32 + TRT_CROSS_ATTENTION = 64 + TRT_CAUSAL_ATTENTION = 128 + + class MultiHeadAttentionConfig: def __init__( self, @@ -62,6 +85,7 @@ def __init__( use_kv_cache: bool = False, share_past_present_buffer: bool = False, input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, + verbose: bool = False, ): self.operator = "MultiHeadAttention" self.batch_size = batch_size @@ -100,6 +124,7 @@ def __init__( self.input_format = input_format self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H + self.verbose = verbose def __repr__(self): return ( @@ -114,89 +139,93 @@ def __repr__(self): ) def shape_dict(self, input_format=None): + shapes: Dict[str, Tuple] = { + "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + input_format = input_format or self.input_format - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - return { + if input_format == InputFormats.QKV_BSN3H: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size), + } + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size), + } + elif input_format == 
InputFormats.Q_K_V_BSNH_BSNH_BSNH: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + else: + assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + shapes = { + **shapes, "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), "key": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), } if self.use_kv_cache: + assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" shapes = { + **shapes, "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), } - else: - shapes = { - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - } - if input_format == InputFormats.QKV_BSN3H: - shapes.update({"query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size)}) - elif input_format == InputFormats.Q_KV_BSNH_BSN2H: - shapes.update( - { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size), - } - ) - else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH - shapes.update( - { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - } - ) return shapes def symbolic_shape_dict(self, input_format=None): + shapes: Dict[str, Tuple] = { + "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + input_format = input_format or self.input_format - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - return { + if input_format == InputFormats.QKV_BSN3H: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size), + } + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size), + } + elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "value": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + else: + assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + shapes = { + **shapes, "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), "key": ("batch_size", self.num_heads, "sequence_length", self.head_size), "value": ("batch_size", self.num_heads, 
"sequence_length", self.head_size), - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), } if self.use_kv_cache: + assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" shapes = { + **shapes, "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), } - else: - shapes = { - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), - } - if input_format == InputFormats.QKV_BSN3H: - shapes.update({"query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size)}) - elif input_format == InputFormats.Q_KV_BSNH_BSN2H: - shapes.update( - { - "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size), - } - ) - else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH - shapes.update( - { - "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "key": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "value": ("batch_size", "sequence_length", self.num_heads * self.head_size), - } - ) return shapes def random_inputs(self, seed: int = 123): @@ -215,44 +244,42 @@ def random_inputs(self, seed: int = 123): k_bnsh = k.transpose(1, 2) v_bnsh = v.transpose(1, 2) - if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - return { + if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + feeds = { "query": q.reshape(shape_dict["query"]), - "key": k_bnsh.contiguous(), - "value": v_bnsh.contiguous(), + "key": k.reshape(shape_dict["key"]), + "value": v.reshape(shape_dict["value"]), } - - feeds = {} - if self.use_kv_cache: - feeds.update( - { - "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_( - mean=0, std=0.1 - ), - "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_( - mean=0, std=0.1 - ), - } - ) - - if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: - feeds.update( - { - "query": q.reshape(shape_dict["query"]), - "key": k.reshape(shape_dict["key"]), - "value": v.reshape(shape_dict["value"]), - } - ) elif self.input_format == InputFormats.QKV_BSN3H: query = q.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) - feeds["query"] = torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous() + feeds = { + "query": torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous(), + } elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H: key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) - feeds["query"] = q.reshape(shape_dict["query"]) - feeds["key"] = torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous() + feeds = { + "query": q.reshape(shape_dict["query"]), + "key": torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous(), + } + else: + assert 
self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + feeds = { + "query": q.reshape(shape_dict["query"]), + "key": k_bnsh.contiguous(), + "value": v_bnsh.contiguous(), + } + + if self.use_kv_cache: + feeds = { + **feeds, + "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_( + mean=0, std=0.1 + ), + } return feeds @@ -318,19 +345,32 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use return model.SerializeToString() -def create_session( +def create_ort_session( config: MultiHeadAttentionConfig, + session_options=None, + attention_kernel=SdpaKernel.DEFAULT, + use_symbolic_shape: bool = True, ) -> CudaSession: - onnx_model_str = create_multi_head_attention_onnx_model(config) + if config.verbose: + print(f"create session for {vars(config)}") + onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=use_symbolic_shape) if config.provider == "CUDAExecutionProvider": device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph) + provider_options["sdpa_kernel"] = int(attention_kernel) providers = [(config.provider, provider_options), "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] - ort_session = InferenceSession(onnx_model_str, providers=providers) + ort_session = InferenceSession(onnx_model_str, session_options, providers=providers) + return ort_session + + +def create_session( + config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT +) -> CudaSession: + ort_session = create_ort_session(config, session_options, attention_kernel, use_symbolic_shape=False) cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph) shape_dict = config.shape_dict() cuda_session.allocate_buffers(shape_dict) @@ -340,11 +380,8 @@ def create_session( class OrtMultiHeadAttention: """A wrapper of ORT MultiHeadAttention to test relevance and performance.""" - def __init__( - self, - config: MultiHeadAttentionConfig, - ): - self.ort_session = create_session(config) + def __init__(self, config: MultiHeadAttentionConfig, session_options=None): + self.ort_session = create_session(config, session_options) self.feed_dict = config.random_inputs() def infer(self): @@ -363,53 +400,112 @@ def flops(batch, sequence_length, head_size, num_heads, causal): def tflops_per_second(flop, time): - return (flop / time / 10**12) if not math.isnan(time) else 0.0 - - -def get_gpu_kernel_name(config: MultiHeadAttentionConfig) -> str: - # This classification is for Nvidia GPU of Compute Capability 8.* like A100. - # Note that some kernel might not exist in older or newer GPUs. - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - if config.input_format == InputFormats.QKV_BSN3H: - min_seq_len = os.getenv("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV") - min_length = int(min_seq_len) if min_seq_len is not None else 513 - if config.sequence_length >= min_length: - return "Flash" - else: - return "Flash" + try: + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + except ZeroDivisionError: + return None + + +def get_gpu_kernel_name(config: MultiHeadAttentionConfig, attention_kernel: SdpaKernel) -> str: + if attention_kernel == SdpaKernel.DEFAULT: + # This classification is for Nvidia GPU of Compute Capability 8.* like A100. 
+ # Note that some kernel might not exist in older or newer GPUs. + if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": + if config.input_format == InputFormats.QKV_BSN3H: + min_seq_len = os.getenv("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV") + min_length = int(min_seq_len) if min_seq_len is not None else 513 + if config.sequence_length >= min_length: + return "ort:flash" + else: + return "ort:flash" + + if (os.getenv("ORT_DISABLE_FUSED_CROSS_ATTENTION") != "1" and config.kv_sequence_length <= 128) or ( + os.getenv("ORT_DISABLE_FUSED_ATTENTION") != "1" + and (config.sequence_length <= 384 or os.getenv("ORT_DISABLE_TRT_FLASH_ATTENTION") != "1") + ): + return "ort:trt" + + if os.getenv("ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION") != "1": + return "ort:efficient" + + return "ort:math" + + kernel_names = { + SdpaKernel.FLASH_ATTENTION: "ort:flash", + SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient", + SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn", + SdpaKernel.MATH: "ort:math", + } + assert attention_kernel in kernel_names + return kernel_names[attention_kernel] - if (os.getenv("ORT_DISABLE_FUSED_CROSS_ATTENTION") != "1" and config.kv_sequence_length <= 128) or ( - os.getenv("ORT_DISABLE_FUSED_ATTENTION") != "1" - and (config.sequence_length <= 384 or os.getenv("ORT_DISABLE_TRT_FLASH_ATTENTION") != "1") - ): - return "TRT" - if os.getenv("ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION") != "1": - return "MemEff" +def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str: + # CPU Flash Attention does not support causal and kv cache etc. + if not (config.causal or config.use_kv_cache or config.past_sequence_length > 0): + if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": + return "cpu:ort:flash" - return "Unfused" + return "cpu:ort:math" -def get_cpu_kernel_name() -> str: - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - return "CPU:Flash" - return "CPU:Unfused" +# ------------------------------------------------------------------ +# Functions for benchmarking PyTorch SDPA +# ------------------------------------------------------------------ +def benchmark_torch_function(func: Callable, *args, **kwargs) -> float: + warmup = 5 + repeats = 100 + for _ in range(warmup): + func(*args, **kwargs) + timer = benchmark.Timer( + stmt="func(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "func": func}, + ) -def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repeats: int = 100): - if use_gpu: - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] - provider = "CUDAExecutionProvider" - print(f"enable_cuda_graph={enable_cuda_graph}") - else: - device_id = 0 - device = torch.device("cpu") - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - enable_cuda_graph = False - provider = "CPUExecutionProvider" + return timer.timeit(number=repeats).median + +def run_torch_sdpa( + batch_size: int, + q_seq_len: int, + kv_seq_len: int, + num_heads: int, + head_size: int, + causal: bool, + device, + dtype, + has_mask: bool = False, + mask_dim: int = 2, + mask_dtype=torch.bool, + backend: Optional[int] = None, +): + q_shape = (batch_size, num_heads, q_seq_len, head_size) + kv_shape = (batch_size, num_heads, kv_seq_len, head_size) + q = torch.randn(q_shape, device=device, dtype=dtype) + k = torch.randn(kv_shape, device=device, dtype=dtype) + v = torch.randn(kv_shape, device=device, dtype=dtype) + + attn_mask = None + if has_mask: + mask_shape = (batch_size, 
num_heads, q_seq_len, kv_seq_len) if mask_dim == 4 else (q_seq_len, kv_seq_len) + attn_mask = torch.ones(mask_shape, dtype=mask_dtype, device=device) + + context = sdpa_kernel(backend) if backend is not None else nullcontext() + + with context: + average_latency = benchmark_torch_function( + scaled_dot_product_attention, + q, + k, + v, + is_causal=causal, + attn_mask=attn_mask, + ) + return average_latency + + +def get_test_configs(use_gpu: bool = True): if use_gpu: # (batch_size, sequence_length, past_sequence_length, num_heads, head_size, run_unfused) configs = [ @@ -450,12 +546,57 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea ] else: configs = [ + # TNLGv4 (1, 128, 0, 32, 128, True), (1, 256, 0, 32, 128, True), (1, 512, 0, 32, 128, True), (1, 1024, 0, 32, 128, True), (1, 2048, 0, 32, 128, True), + # bert-base + (1, 128, 0, 12, 64, True), + (1, 384, 0, 12, 64, True), + (1, 512, 0, 12, 64, True), + (4, 128, 0, 12, 64, True), + (4, 384, 0, 12, 64, True), + (4, 512, 0, 12, 64, True), + # bert-large + (1, 128, 0, 16, 64, True), + (1, 384, 0, 16, 64, True), + (1, 512, 0, 16, 64, True), + (4, 128, 0, 16, 64, True), + (4, 384, 0, 16, 64, True), + (4, 512, 0, 16, 64, True), ] + return configs + + +def run_tflops_test( + csv_writer: csv.DictWriter, + use_gpu: bool = True, + enable_cuda_graph: bool = False, + causal: bool = False, + has_past: bool = False, + intra_op_num_threads: int = 0, + repeats: int = 100, +): + print(f"run_tflops_test: causal={causal}") + + if use_gpu: + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] + provider = "CUDAExecutionProvider" + print(f"enable_cuda_graph={enable_cuda_graph}") + backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] + else: + device_id = 0 + device = torch.device("cpu") + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + enable_cuda_graph = False + provider = "CPUExecutionProvider" + backends = [SdpaKernel.DEFAULT] + + configs = get_test_configs(use_gpu) # List of environment variables to enable/disable attention kernels print("Environment Variables:") @@ -468,13 +609,17 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea "ORT_DISABLE_FUSED_CROSS_ATTENTION", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", ] + + env_list = "" for name in env_names: value = os.getenv(name) if value is not None: print(f"{name}={value}") + if env_list: + env_list += "," + env_list += f"{name}={value}" - print("\nformat\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel") - causal = False + print("\nformat\tcausal\tprompt\tbatch\tseqlen\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: @@ -496,21 +641,27 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea share_past_present_buffer=False, input_format=input_format, ) - - session = create_session(config) + for attention_kernel in backends: + sess_options = SessionOptions() + sess_options.intra_op_num_threads = intra_op_num_threads + session = create_session(config, sess_options, attention_kernel=attention_kernel) if use_gpu: - kernel = get_gpu_kernel_name(config) + kernel = get_gpu_kernel_name(config, attention_kernel) else: - kernel = get_cpu_kernel_name() + kernel = get_cpu_kernel_name(config) - if kernel == "Unfused": + if 
"math" in kernel: # Skip large sequence length for Unfused kernel to avoid OOM. if not enable_unfused: + if config.verbose: + print(f"skip unfused kernel for {vars(config)}") continue # Unfused kernel does not support packed QKV or packed KV formats. if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + if config.verbose: + print(f"skip input_format for {vars(config)}") continue input_dict = config.random_inputs() @@ -526,17 +677,174 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea del session + format_str = InputFormats.input_format_str(input_format) + # compute TFLOPS per second - speed = tflops_per_second( - flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency - ) + speed = "NA" + if not has_past: + speed = tflops_per_second( + flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency + ) + + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": enable_cuda_graph, + "format": format_str, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + "intra_op_num_threads": intra_op_num_threads, + "average_latency": average_latency, + "tflops": speed, + "kernel": kernel, + "environment_variables": env_list, + } + csv_writer.writerow(row) - format = InputFormats.input_format_str(input_format) + speed = f"{speed:.2f}" if speed is not None else "NA" print( - f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + f"{format_str}\t{causal}\t{not has_past}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" ) +def run_torch_test( + csv_writer: csv.DictWriter, + use_gpu: bool = True, + causal: bool = False, +): + configs = get_test_configs(use_gpu) + + if use_gpu: + if not torch.cuda.is_available(): + return + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + backends = [ + None, + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + SDPBackend.MATH, + ] + else: + device_id = 0 + device = torch.device("cpu") + dtype = torch.float32 + backends = [None] + + backend_names = { + SDPBackend.FLASH_ATTENTION: "torch:flash", + SDPBackend.EFFICIENT_ATTENTION: "torch:efficient", + SDPBackend.CUDNN_ATTENTION: "torch:cudnn", + SDPBackend.MATH: "torch:math", + None: "torch:default", + } + + # Test PyTorch latency + for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: + for backend in backends: + if backend == SDPBackend.MATH and not enable_unfused: + continue + if backend == SDPBackend.FLASH_ATTENTION and platform.system() != "Linux": + continue + + backend_name = backend_names[backend] + try: + with torch.no_grad(): + torch_latency = run_torch_sdpa( + batch_size, + sequence_length, + sequence_length, + num_heads, + head_size, + causal, + has_mask=False, + mask_dim=2, + mask_dtype=torch.bool, + device=device, + dtype=dtype, + backend=backend, + ) + except RuntimeError: + continue + + speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) + kernel = ("Torch" if use_gpu else "Torch:cpu") + (f":{backend_name}" if backend is not None else "") + input_format = "Q,K,V" + print( + 
f"{input_format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{0}\t{torch_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + ) + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": False, + "format": input_format, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + "intra_op_num_threads": torch.get_num_threads(), + "average_latency": torch_latency, + "tflops": speed, + "kernel": kernel, + "environment_variables": "", + } + csv_writer.writerow(row) + + +def run_tflops_tests(args): + features = "gpu" if args.use_gpu else "cpu" + if args.causal: + features += "_causal" + if args.has_past: + features += "_past" + csv_filename = "benchmark_mha_{}_{}_{}.csv".format( + features, + "torch" if args.torch else "ort", + datetime.now().strftime("%Y%m%d-%H%M%S"), + ) + with open(csv_filename, mode="a", newline="") as csv_file: + column_names = [ + "use_gpu", + "enable_cuda_graph", + "format", + "causal", + "batch_size", + "sequence_length", + "past_sequence_length", + "num_heads", + "head_size", + "intra_op_num_threads", + "average_latency", + "tflops", + "kernel", + "environment_variables", + ] + csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) + csv_writer.writeheader() + + if args.torch: + assert Version(torch.__version__) >= Version("2.3.0") + assert args.has_past is False + run_torch_test(csv_writer, args.use_gpu, args.causal) + else: + run_tflops_test( + csv_writer, + use_gpu=args.use_gpu, + enable_cuda_graph=args.use_cuda_graph, + causal=args.causal, + has_past=args.has_past, + intra_op_num_threads=args.intra_op_num_threads, + ) + + def plot_prompt_performance( sm: int, model_name: str, @@ -591,13 +899,14 @@ def benchmark( sequence_length=sequence_length, num_heads=num_heads, head_size=head_size, - causal=True, + causal=False, past_sequence_length=0, kv_sequence_length=sequence_length if input_format == InputFormats.get_name_list()[-1] else None, max_cache_sequence_length=max_seq_len, provider="CUDAExecutionProvider", enable_cuda_graph=False, device=device, + dtype=torch.float16, use_kv_cache=False, input_format=InputFormats.convert(input_format), ) @@ -609,14 +918,14 @@ def benchmark( benchmark.run(save_path=".", print_data=True) -def run_performance_test(sm: int): +def run_bert_performance_test(sm: int): """ Run performance tests for prompt and token generation. """ configures = [ - (1, 32, 128, 8192, "TNLGv4"), - (4, 32, 128, 8192, "TNLGv4"), + # (1, 32, 128, 8192, "TNLGv4"), + # (4, 32, 128, 8192, "TNLGv4"), (1, 12, 64, 1024, "BertBase"), (16, 12, 64, 1024, "BertBase"), (1, 16, 64, 1024, "BertLarge"), @@ -634,18 +943,86 @@ def run_performance_test(sm: int): ) +def _parse_arguments(): + parser = argparse.ArgumentParser(description="Benchmark MultiHeadAttention for ONNX Runtime and PyTorch.") + + parser.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for inference.", + ) + parser.set_defaults(use_gpu=False) + + parser.add_argument( + "--use_cuda_graph", + required=False, + action="store_true", + help="Use cuda graph in onnxruntime.", + ) + parser.set_defaults(use_cuda_graph=False) + + parser.add_argument( + "--intra_op_num_threads", + required=False, + type=int, + choices=[0, 1, 2, 4, 8, 16], + default=0, + help="intra_op_num_threads for onnxruntime. 
", + ) + + parser.add_argument( + "--has_past", + required=False, + action="store_true", + help="whether past_sequence_length > 0", + ) + parser.set_defaults(has_past=False) + + parser.add_argument( + "--causal", + required=False, + action="store_true", + help="test unidirectional", + ) + parser.set_defaults(causal=False) + + parser.add_argument( + "--torch", + required=False, + action="store_true", + help="test pytorch instead of onnxruntime", + ) + parser.set_defaults(torch=False) + + args = parser.parse_args() + + return args + + if __name__ == "__main__": - if torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers(): - # Test CUDA provider + args = _parse_arguments() + print(f"arguments:{args}") + + if args.has_past: + assert args.causal, "--has_past need --causal specified" + + if args.use_gpu: + assert args.torch or not args.causal, "no causl cuda kernel in MHA op" + assert torch.cuda.is_available() + if not args.torch: + assert "CUDAExecutionProvider" in get_available_providers() + + if args.use_gpu and not args.torch: major, minor = torch.cuda.get_device_capability() sm = major * 10 + minor + if major != 8: + print(f"Warning: ORT kernel name logic in this script is for sm=8x, current device is {major}{minor}.") + if platform.system() == "Linux": s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): - run_performance_test(sm) - - run_tflops_test(use_gpu=True, enable_cuda_graph=True) + run_bert_performance_test(sm) - # Test CPU provider - run_tflops_test(use_gpu=False, enable_cuda_graph=False) + run_tflops_tests(args) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index 7b21cf1cc1e08..0080ee111d900 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -1,14 +1,32 @@ -echo "flash attention v2" -ORT_DISABLE_FLASH_ATTENTION=0 ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV=0 python benchmark_mha.py | tee result.txt +echo "Benchmark performance on GPU:" -echo "===" -echo "TensorRT attention kernels - cross attention (when kv_seq_len <= 128) or fused attention (when seq_len <= 384) or flash attention (seq_len > 384)" -ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +export CUDA_VISIBLE_DEVICES=0 +python benchmark_mha.py --use_gpu +python benchmark_mha.py --use_gpu --use_cuda_graph +python benchmark_mha.py --use_gpu --torch -echo "===" -echo "Memory Efficient attention" -ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv -echo "===" -echo "Unfused Attention (some configurations might fail)" -ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +echo "Benchmark performance on CPU with number of threads:" +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + +python benchmark_mha.py --intra_op_num_threads 1 +python benchmark_mha.py --intra_op_num_threads 2 +python benchmark_mha.py --intra_op_num_threads 4 +python 
benchmark_mha.py --intra_op_num_threads 8 + + +echo "Benchmark performance on CPU with default threads settings:" +python benchmark_mha.py +python benchmark_mha.py --torch + +python benchmark_mha.py --causal +python benchmark_mha.py --torch --causal + +# Pytorch SDPA does not support causal attention with past state, we only test ORT here. +python benchmark_mha.py --causal --has_past + +cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index ff473cc2ced92..0fcbd889847e9 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -10,36 +10,15 @@ import concurrent.futures import itertools import unittest -from enum import IntEnum from typing import Dict, List, Optional import numpy import torch -from benchmark_mha import ( - InputFormats, - MultiHeadAttentionConfig, - OrtMultiHeadAttention, - create_multi_head_attention_onnx_model, -) +from benchmark_mha import InputFormats, MultiHeadAttentionConfig, OrtMultiHeadAttention, SdpaKernel, create_ort_session from einops import rearrange from parameterized import parameterized import onnxruntime -from onnxruntime import InferenceSession - - -class SdpaKernel(IntEnum): - """Bit flags for sdpa_kernel CUDA provider option""" - - DEFAULT = 0 - FLASH_ATTENTION = 1 - EFFICIENT_ATTENTION = 2 - TRT_FUSED_ATTENTION = 4 - CUDNN_FLASH_ATTENTION = 8 - MATH = 16 - TRT_FLASH_ATTENTION = 32 - TRT_CROSS_ATTENTION = 64 - TRT_CAUSAL_ATTENTION = 128 def attention_reference( @@ -466,7 +445,7 @@ def parity_check_mha_multi_threading( test_inputs: List[Dict], rtol: float = 1e-3, atol: float = 1e-3, - sdpa_kernel: int = SdpaKernel.DEFAULT, + attention_kernel: int = SdpaKernel.DEFAULT, max_threads: int = 5, verbose: bool = False, ): @@ -476,21 +455,14 @@ def parity_check_mha_multi_threading( if config.causal and config.provider == "CUDAExecutionProvider": return None # Some kernel does not support certain input format. 
- if sdpa_kernel not in [ + if attention_kernel not in [ SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]: return None - if verbose: - print(f"create a shared session with {vars(config)}") - onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=True) - if config.provider == "CUDAExecutionProvider": - provider_options = {"arena_extend_strategy": "kSameAsRequested", "sdpa_kernel": int(sdpa_kernel)} - providers = [(config.provider, provider_options), "CPUExecutionProvider"] - else: - providers = ["CPUExecutionProvider"] - ort_session = InferenceSession(onnx_model_str, providers=providers) + + ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True) def convert_to_ort_inputs(feed_dict): ort_inputs = {} @@ -613,7 +585,7 @@ def test_mha_cuda(self, config): def test_mha_cpu(self, config): parity_check_mha(config) - def run_mha_cuda_multi_threading(self, spda_kernel): + def run_mha_cuda_multi_threading(self, attention_kernel): for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): test_inputs = [] for config in configs: @@ -626,8 +598,10 @@ def run_mha_cuda_multi_threading(self, spda_kernel): config.input_format = old_format test_inputs.append({"config": config, "ort_inputs": ort_inputs, "ref_inputs": ref_inputs}) - exception = parity_check_mha_multi_threading(test_inputs, sdpa_kernel=spda_kernel, max_threads=len(configs)) - assert exception is None, f"{spda_kernel=}, {vars(configs[0])}, {exception}" + exception = parity_check_mha_multi_threading( + test_inputs, attention_kernel=attention_kernel, max_threads=len(configs) + ) + assert exception is None, f"{attention_kernel=}, {vars(configs[0])}, {exception}" def test_mha_cuda_multi_threading(self): self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) From e4caf29dc3049697c3abdd5f1e38c3950d8fa32b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 23 Jul 2024 00:46:31 +0000 Subject: [PATCH 2/5] add comment --- onnxruntime/test/python/transformers/benchmark_mha.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index 0080ee111d900..7027c556a2ab6 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -1,4 +1,11 @@ -echo "Benchmark performance on GPU:" +#!/bin/sh + +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" export CUDA_VISIBLE_DEVICES=0 python benchmark_mha.py --use_gpu From f3f95f2e5cafab5b78a02facfdea227b0a8c7501 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 23 Jul 2024 00:48:57 +0000 Subject: [PATCH 3/5] remove unused variable --- onnxruntime/test/python/transformers/benchmark_mha.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 4030b9369dcf0..2c79b25198332 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -732,7 +732,6 @@ def run_torch_test( SDPBackend.MATH, ] else: - device_id = 0 device = torch.device("cpu") dtype = torch.float32 backends = [None] From ecff18ba2411d2dc90fdf4c341cedd380b09b73f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jul 2024 22:21:21 -0700 Subject: [PATCH 4/5] Add Windows batch file --- .../python/transformers/benchmark_mha.cmd | 47 ++++++++ .../test/python/transformers/benchmark_mha.py | 110 ++++++------------ .../test/python/transformers/benchmark_mha.sh | 1 + 3 files changed, 85 insertions(+), 73 deletions(-) create mode 100644 onnxruntime/test/python/transformers/benchmark_mha.cmd diff --git a/onnxruntime/test/python/transformers/benchmark_mha.cmd b/onnxruntime/test/python/transformers/benchmark_mha.cmd new file mode 100644 index 0000000000000..f1443f81b4ab5 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_mha.cmd @@ -0,0 +1,47 @@ +echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" + +set CUDA_VISIBLE_DEVICES=0 +python benchmark_mha.py --use_gpu +python benchmark_mha.py --use_gpu --use_cuda_graph +python benchmark_mha.py --use_gpu --torch + +type benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + +echo "Benchmark performance on CPU with number of threads:" +set MKL_DYNAMIC=FALSE +set OMP_NUM_THREADS=1 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=2 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=4 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=8 +python benchmark_mha.py --torch + +set MKL_DYNAMIC= +set OMP_NUM_THREADS= + +python benchmark_mha.py --intra_op_num_threads 1 +python benchmark_mha.py --intra_op_num_threads 2 +python benchmark_mha.py --intra_op_num_threads 4 +python benchmark_mha.py --intra_op_num_threads 8 + + +echo "Benchmark performance on CPU with default threads settings:" +python benchmark_mha.py + +set ORT_DISABLE_FLASH_ATTENTION=1 +python benchmark_mha.py +set ORT_DISABLE_FLASH_ATTENTION= + +python benchmark_mha.py --torch + +python benchmark_mha.py --causal +python benchmark_mha.py --torch --causal + +python benchmark_mha.py --causal --has_past + +type benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 2c79b25198332..3404da218250c 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -4,8 +4,13 @@ # -------------------------------------------------------------------------- """ -Benchmark performance of MultiHeadAttention with ORT or PyTorch. For example, run the the following in Linux: -sh benchmark_mha.sh +Benchmark performance of MultiHeadAttention with ORT or PyTorch. 
+ +In Linux, run the the following: + sh benchmark_mha.sh + +In Windows, run the the following: + benchmark_mha.cmd """ import argparse @@ -406,31 +411,9 @@ def tflops_per_second(flop, time): return None -def get_gpu_kernel_name(config: MultiHeadAttentionConfig, attention_kernel: SdpaKernel) -> str: - if attention_kernel == SdpaKernel.DEFAULT: - # This classification is for Nvidia GPU of Compute Capability 8.* like A100. - # Note that some kernel might not exist in older or newer GPUs. - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - if config.input_format == InputFormats.QKV_BSN3H: - min_seq_len = os.getenv("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV") - min_length = int(min_seq_len) if min_seq_len is not None else 513 - if config.sequence_length >= min_length: - return "ort:flash" - else: - return "ort:flash" - - if (os.getenv("ORT_DISABLE_FUSED_CROSS_ATTENTION") != "1" and config.kv_sequence_length <= 128) or ( - os.getenv("ORT_DISABLE_FUSED_ATTENTION") != "1" - and (config.sequence_length <= 384 or os.getenv("ORT_DISABLE_TRT_FLASH_ATTENTION") != "1") - ): - return "ort:trt" - - if os.getenv("ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION") != "1": - return "ort:efficient" - - return "ort:math" - +def get_gpu_kernel_name(attention_kernel: SdpaKernel) -> str: kernel_names = { + SdpaKernel.DEFAULT: "ort:default", SdpaKernel.FLASH_ATTENTION: "ort:flash", SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient", SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn", @@ -444,9 +427,9 @@ def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str: # CPU Flash Attention does not support causal and kv cache etc. if not (config.causal or config.use_kv_cache or config.past_sequence_length > 0): if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - return "cpu:ort:flash" + return "ort:flash" - return "cpu:ort:math" + return "ort:math" # ------------------------------------------------------------------ @@ -570,6 +553,13 @@ def get_test_configs(use_gpu: bool = True): return configs +def get_compute_capability(): + assert torch.cuda.is_available() + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + return sm + + def run_tflops_test( csv_writer: csv.DictWriter, use_gpu: bool = True, @@ -586,8 +576,12 @@ def run_tflops_test( device = torch.device("cuda", device_id) formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] provider = "CUDAExecutionProvider" - print(f"enable_cuda_graph={enable_cuda_graph}") - backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] + # flash attention is available for sm >= 80 + sm = get_compute_capability() + if sm >= 80: + backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] + else: + backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION] else: device_id = 0 device = torch.device("cpu") @@ -598,27 +592,6 @@ def run_tflops_test( configs = get_test_configs(use_gpu) - # List of environment variables to enable/disable attention kernels - print("Environment Variables:") - env_names = [ - "ORT_DISABLE_FLASH_ATTENTION", - "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV", - "ORT_DISABLE_FUSED_ATTENTION", - "ORT_DISABLE_TRT_FLASH_ATTENTION", - "ORT_ENABLE_FUSED_CAUSAL_ATTENTION", - "ORT_DISABLE_FUSED_CROSS_ATTENTION", - "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", - ] - - env_list = "" - for name in env_names: - value = os.getenv(name) - if value is not None: - print(f"{name}={value}") - if env_list: - env_list += "," - env_list += 
f"{name}={value}" - print("\nformat\tcausal\tprompt\tbatch\tseqlen\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") for input_format in formats: @@ -647,7 +620,7 @@ def run_tflops_test( session = create_session(config, sess_options, attention_kernel=attention_kernel) if use_gpu: - kernel = get_gpu_kernel_name(config, attention_kernel) + kernel = get_gpu_kernel_name(attention_kernel) else: kernel = get_cpu_kernel_name(config) @@ -680,8 +653,8 @@ def run_tflops_test( format_str = InputFormats.input_format_str(input_format) # compute TFLOPS per second - speed = "NA" - if not has_past: + speed = None + if past_sequence_length == 0: speed = tflops_per_second( flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency ) @@ -700,7 +673,6 @@ def run_tflops_test( "average_latency": average_latency, "tflops": speed, "kernel": kernel, - "environment_variables": env_list, } csv_writer.writerow(row) @@ -773,11 +745,10 @@ def run_torch_test( continue speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) - kernel = ("Torch" if use_gpu else "Torch:cpu") + (f":{backend_name}" if backend is not None else "") input_format = "Q,K,V" print( f"{input_format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" - f"{0}\t{torch_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + f"{0}\t{torch_latency * 1000:.2f}\t{speed:.2f}\t{backend_name}" ) row = { "use_gpu": use_gpu, @@ -792,8 +763,7 @@ def run_torch_test( "intra_op_num_threads": torch.get_num_threads(), "average_latency": torch_latency, "tflops": speed, - "kernel": kernel, - "environment_variables": "", + "kernel": backend_name, } csv_writer.writerow(row) @@ -824,14 +794,11 @@ def run_tflops_tests(args): "average_latency", "tflops", "kernel", - "environment_variables", ] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) csv_writer.writeheader() if args.torch: - assert Version(torch.__version__) >= Version("2.3.0") - assert args.has_past is False run_torch_test(csv_writer, args.use_gpu, args.causal) else: run_tflops_test( @@ -845,7 +812,6 @@ def run_tflops_tests(args): def plot_prompt_performance( - sm: int, model_name: str, batch_size: int, num_heads: int, @@ -865,6 +831,7 @@ def plot_prompt_performance( "styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")][0 : len(formats)], } + sm = get_compute_capability() configs = [ triton.testing.Benchmark( x_names=["sequence_length"], @@ -917,7 +884,7 @@ def benchmark( benchmark.run(save_path=".", print_data=True) -def run_bert_performance_test(sm: int): +def run_bert_performance_test(): """ Run performance tests for prompt and token generation. 
@@ -933,7 +900,6 @@ def run_bert_performance_test(sm: int): for batch_size, num_heads, head_size, max_seq_len, model_name in configures: plot_prompt_performance( - sm=sm, batch_size=batch_size, num_heads=num_heads, head_size=head_size, @@ -1007,21 +973,19 @@ def _parse_arguments(): assert args.causal, "--has_past need --causal specified" if args.use_gpu: - assert args.torch or not args.causal, "no causl cuda kernel in MHA op" + assert args.torch or not args.causal, "no causal cuda kernel in MHA op" assert torch.cuda.is_available() if not args.torch: assert "CUDAExecutionProvider" in get_available_providers() - if args.use_gpu and not args.torch: - major, minor = torch.cuda.get_device_capability() - sm = major * 10 + minor - - if major != 8: - print(f"Warning: ORT kernel name logic in this script is for sm=8x, current device is {major}{minor}.") + if args.torch: + assert Version(torch.__version__) >= Version("2.3.0") + assert args.has_past is False + if args.use_gpu and not args.torch: if platform.system() == "Linux": s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): - run_bert_performance_test(sm) + run_bert_performance_test() run_tflops_tests(args) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index 7027c556a2ab6..613543d0172dd 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -28,6 +28,7 @@ python benchmark_mha.py --intra_op_num_threads 8 echo "Benchmark performance on CPU with default threads settings:" python benchmark_mha.py +ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py python benchmark_mha.py --torch python benchmark_mha.py --causal From 85f9a139382aacca064d580021571660455f17f7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 23 Jul 2024 07:09:14 +0000 Subject: [PATCH 5/5] minor change --- onnxruntime/test/python/transformers/benchmark_mha.cmd | 10 +++++----- onnxruntime/test/python/transformers/benchmark_mha.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.cmd b/onnxruntime/test/python/transformers/benchmark_mha.cmd index f1443f81b4ab5..0a6d0c37b4a35 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.cmd +++ b/onnxruntime/test/python/transformers/benchmark_mha.cmd @@ -24,19 +24,15 @@ python benchmark_mha.py --torch set MKL_DYNAMIC= set OMP_NUM_THREADS= +set ORT_DISABLE_FLASH_ATTENTION=0 python benchmark_mha.py --intra_op_num_threads 1 python benchmark_mha.py --intra_op_num_threads 2 python benchmark_mha.py --intra_op_num_threads 4 python benchmark_mha.py --intra_op_num_threads 8 - echo "Benchmark performance on CPU with default threads settings:" python benchmark_mha.py -set ORT_DISABLE_FLASH_ATTENTION=1 -python benchmark_mha.py -set ORT_DISABLE_FLASH_ATTENTION= - python benchmark_mha.py --torch python benchmark_mha.py --causal @@ -44,4 +40,8 @@ python benchmark_mha.py --torch --causal python benchmark_mha.py --causal --has_past +set ORT_DISABLE_FLASH_ATTENTION=1 +python benchmark_mha.py +set ORT_DISABLE_FLASH_ATTENTION= + type benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 3404da218250c..715a92431e6bf 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -534,7 +534,7 @@ def get_test_configs(use_gpu: 
bool = True): (1, 256, 0, 32, 128, True), (1, 512, 0, 32, 128, True), (1, 1024, 0, 32, 128, True), - (1, 2048, 0, 32, 128, True), + # (1, 2048, 0, 32, 128, True), # bert-base (1, 128, 0, 12, 64, True), (1, 384, 0, 12, 64, True),
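
For reference, a minimal usage sketch of the two code paths this patch series benchmarks: ONNX Runtime kernel selection through the new `sdpa_kernel` CUDA provider option (consuming the SdpaKernel bit flags defined in benchmark_mha.py), and the PyTorch SDPA baseline pinned to one backend via the sdpa_kernel context manager. This sketch is not part of the patch; it assumes onnxruntime-gpu with the CUDA execution provider, PyTorch >= 2.3 on a CUDA device, and "mha.onnx" is a placeholder path for any model containing a MultiHeadAttention node.

# --- ONNX Runtime: select an attention kernel through the CUDA provider option ---
import onnxruntime
from benchmark_mha import SdpaKernel  # bit flags defined in this patch

provider_options = {"sdpa_kernel": int(SdpaKernel.EFFICIENT_ATTENTION)}
session = onnxruntime.InferenceSession(
    "mha.onnx",  # placeholder model path; any model with a MultiHeadAttention node
    providers=[("CUDAExecutionProvider", provider_options), "CPUExecutionProvider"],
)

# --- PyTorch: the SDPA baseline pinned to one backend ---
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention

# Shapes follow the (batch, num_heads, seq_len, head_size) layout used by run_torch_sdpa.
q = torch.randn(4, 12, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION), torch.no_grad():
    out = scaled_dot_product_attention(q, k, v, is_causal=False)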