From bb84371d2e98943c6cf4e3661bc1f9c346f4b942 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Wed, 3 Jul 2024 00:34:10 -0700
Subject: [PATCH] update test

---
 .../transformers/test_sparse_attention.py | 91 +++++++++++--------
 1 file changed, 55 insertions(+), 36 deletions(-)

diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py
index 64877fb257e20..eb892ac91a7f5 100644
--- a/onnxruntime/test/python/transformers/test_sparse_attention.py
+++ b/onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -8,7 +8,7 @@
 """
 import math
 import unittest
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from benchmark_mha import InputFormats
@@ -17,7 +17,7 @@
 from torch import Tensor
 
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
-from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
+from onnxruntime.transformers.io_binding_helper import CudaSession
 
 ENABLE_DEBUG = False
 
@@ -616,7 +616,10 @@ def group_query_attention_reference(
     attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value)
 
     result = attn_output.transpose(1, 2).contiguous()
-    torch.cuda.synchronize()
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
     return result
 
 
@@ -688,25 +691,42 @@ def infer(self):
         )
 
 
+def create_ort_session(
+    config: Union[SparseAttentionConfig, GroupQueryAttentionConfig], session_options=None, enable_cuda_graph=False
+) -> CudaSession:
+    if isinstance(config, SparseAttentionConfig):
+        onnx_model_str = create_sparse_attention_onnx_model(config)
+    else:
+        onnx_model_str = create_group_query_attention_onnx_model(config)
+
+    if config.provider == "CUDAExecutionProvider":
+        device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
+        provider_options = CudaSession.get_cuda_provider_options(
+            device_id, enable_cuda_graph=enable_cuda_graph, stream=torch.cuda.current_stream().cuda_stream
+        )
+        providers = [(config.provider, provider_options), "CPUExecutionProvider"]
+    else:
+        providers = ["CPUExecutionProvider"]
+
+    ort_session = InferenceSession(onnx_model_str, session_options, providers=providers)
+
+    # Note that CudaSession could work with both CUDA and CPU providers.
+    cuda_session = CudaSession(ort_session, config.device, enable_cuda_graph=enable_cuda_graph)
+    shape_dict = config.shape_dict()
+    cuda_session.allocate_buffers(shape_dict)
+
+    buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
+    for input_name, output_name in buffer_sharing.items():
+        cuda_session.set_buffer_sharing(input_name, output_name)
+
+    return cuda_session
+
+
 class OrtGroupQueryAttention:
     """A wrapper of ORT GroupQueryAttention to test relevance and performance."""
 
     def __init__(self, config: GroupQueryAttentionConfig):
-        cuda_provider_options = CudaSession.get_cuda_provider_options(
-            torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
-        )
-        onnx_model_str = create_group_query_attention_onnx_model(config)
-        self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
-        self.gpu_binding_manager = GpuBindingManager(
-            ort_session=self.ort_session,
-            device=config.device,
-            stream=torch.cuda.current_stream().cuda_stream,
-            max_cuda_graphs=2,
-        )
-        buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
-        self.gpu_binding = self.gpu_binding_manager.get_binding(
-            config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
-        )
+        self.session = create_ort_session(config)
+
         self.feed_dict = config.random_inputs()
 
         if ENABLE_DEBUG and not config.is_packed_qkv:
@@ -726,28 +746,14 @@ def __init__(self, config: GroupQueryAttentionConfig):
             print("seqlens_k (BSNH, GQA)", self.feed_dict["seqlens_k"])
 
     def infer(self):
-        return self.gpu_binding.infer(self.feed_dict)
+        return self.session.infer(self.feed_dict)
 
 
 class OrtSparseAttention:
     """A wrapper of ORT SparseAttention to test relevance and performance."""
 
     def __init__(self, config: SparseAttentionConfig):
-        cuda_provider_options = CudaSession.get_cuda_provider_options(
-            torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
-        )
-        onnx_model_str = create_sparse_attention_onnx_model(config)
-        self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
-        self.gpu_binding_manager = GpuBindingManager(
-            ort_session=self.ort_session,
-            device=config.device,
-            stream=torch.cuda.current_stream().cuda_stream,
-            max_cuda_graphs=2,
-        )
-        buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
-        self.gpu_binding = self.gpu_binding_manager.get_binding(
-            config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
-        )
+        self.session = create_ort_session(config)
 
         self.feed_dict = config.random_inputs()
 
         if ENABLE_DEBUG and not config.is_packed_qkv:
@@ -770,7 +776,7 @@ def __init__(self, config: SparseAttentionConfig):
             print("key_total_sequence_lengths", self.feed_dict["key_total_sequence_lengths"])
 
     def infer(self):
-        return self.gpu_binding.infer(self.feed_dict)
+        return self.session.infer(self.feed_dict)
 
 
 def get_provider_support_info(provider: str, use_kv_cache: bool):
@@ -817,6 +823,7 @@ def get_simple_test_case(provider: str, has_past_kv: bool):
        local_blocks=2,
        vert_stride=2,
        softmax_scale=0.0,
+       provider=provider,
        device=device,
        dtype=dtype,
        is_packed_qkv=packed_qkv,
@@ -834,7 +841,9 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
     batch_sizes = [1, 2, 3]
     sequence_lengths = [1, 64, 127, 128, 192, 256]
     heads = [4, 8, 16]
-    head_sizes = [128, 256]
+
+    # SparseAttention CUDA kernel only supports head size 128
+    head_sizes = [128] if provider == "CUDAExecutionProvider" else [128, 256]
 
     if comprehensive:
         for batch_size in batch_sizes:
@@ -865,6 +874,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
                         local_blocks=2,
                         vert_stride=2,
                         softmax_scale=1.8 / (128**0.5),
+                        provider=provider,
                         device=device,
                         dtype=dtype,
                         is_packed_qkv=packed_qkv,
@@ -903,6 +913,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
            local_blocks=2,
            vert_stride=2,
            softmax_scale=1.8 / (128**0.5),
+           provider=provider,
            device=device,
            dtype=dtype,
            is_packed_qkv=packed_qkv,
@@ -963,6 +974,10 @@ def run_one_relevance_test(self, config: SparseAttentionConfig):
            obj = TorchGroupQueryAttention(gqa_config)
            expected_out = obj.infer()
        else:
+           if config.dtype == torch.bfloat16:
+               # Skip test since create_group_query_attention_onnx_model does not support bfloat16 right now.
+               return
+
            # Run GQA by ORT (supports packed QKV, rotary and very long sequences, but no mask, so dense only).
            gqa_config: GroupQueryAttentionConfig = config.get_comparable_ort_gqa_config(use_local=False)
            obj = OrtGroupQueryAttention(gqa_config)
@@ -1070,6 +1085,8 @@ def run_relevance_no_past_128k(self, sm: int, device):
            local_blocks=2048,  # use dense to compare with GQA
            vert_stride=8,
            softmax_scale=None,
+           provider="CUDAExecutionProvider",
+           dtype=torch.float16,
            device=device,
            is_packed_qkv=packed_qkv,
        )
@@ -1096,6 +1113,8 @@ def run_relevance_past_128k(self, sm: int, device):
            local_blocks=2048,  # use dense to compare with GQA
            vert_stride=8,
            softmax_scale=None,
+           provider="CUDAExecutionProvider",
+           dtype=torch.float16,
            device=device,
            is_packed_qkv=packed_qkv,
        )
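
Reviewer note: the sketch below (not part of the patch) illustrates the CudaSession buffer-sharing pattern that the new create_ort_session helper wraps. It uses only CudaSession calls that appear in the diff above; onnx_model_str, shape_dict, and feed_dict are placeholders the caller supplies, e.g. from config.shape_dict() and config.random_inputs() in this test.

import torch
from onnxruntime import InferenceSession
from onnxruntime.transformers.io_binding_helper import CudaSession

def run_with_shared_kv_cache(onnx_model_str: bytes, shape_dict: dict, feed_dict: dict) -> dict:
    """Sketch: run a KV-cache model with past/present buffers aliased in place."""
    # CudaSession also works with the CPU provider (see the note in the diff above).
    ort_session = InferenceSession(onnx_model_str, providers=["CPUExecutionProvider"])
    session = CudaSession(ort_session, device=torch.device("cpu"), enable_cuda_graph=False)

    # Pre-allocate I/O buffers for the fixed shapes, then alias the KV cache so
    # present_key/present_value are written over past_key/past_value each step.
    session.allocate_buffers(shape_dict)
    session.set_buffer_sharing("past_key", "present_key")
    session.set_buffer_sharing("past_value", "present_value")

    return session.infer(feed_dict)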