
update test
tianleiwu committed Jul 3, 2024
1 parent e4ac550 commit bb84371
Showing 1 changed file with 55 additions and 36 deletions.
91 changes: 55 additions & 36 deletions onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -8,7 +8,7 @@
"""
import math
import unittest
from typing import Optional
from typing import Optional, Union

import torch
from benchmark_mha import InputFormats
@@ -17,7 +17,7 @@
from torch import Tensor

from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
from onnxruntime.transformers.io_binding_helper import CudaSession

ENABLE_DEBUG = False

@@ -616,7 +616,10 @@ def group_query_attention_reference(
attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value)

result = attn_output.transpose(1, 2).contiguous()
torch.cuda.synchronize()

if torch.cuda.is_available():
torch.cuda.synchronize()

return result


@@ -688,25 +691,42 @@ def infer(self):
)


def create_ort_session(
config: Union[SparseAttentionConfig, GroupQueryAttentionConfig], session_options=None, enable_cuda_graph=False
) -> CudaSession:
if isinstance(config, SparseAttentionConfig):
onnx_model_str = create_sparse_attention_onnx_model(config)
else:
onnx_model_str = create_group_query_attention_onnx_model(config)

if config.provider == "CUDAExecutionProvider":
device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
provider_options = CudaSession.get_cuda_provider_options(
device_id, enable_cuda_graph=enable_cuda_graph, stream=torch.cuda.current_stream().cuda_stream
)
providers = [(config.provider, provider_options), "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]

ort_session = InferenceSession(onnx_model_str, session_options, providers=providers)
# Note that CudaSession could work with both CUDA and CPU providers.
cuda_session = CudaSession(ort_session, config.device, enable_cuda_graph=enable_cuda_graph)
shape_dict = config.shape_dict()
cuda_session.allocate_buffers(shape_dict)

buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
for input_name, output_name in buffer_sharing.items():
cuda_session.set_buffer_sharing(input_name, output_name)

return cuda_session
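
A minimal usage sketch of the new create_ort_session helper (an illustration, not part of this commit), assuming a SparseAttentionConfig or GroupQueryAttentionConfig already built by the test helpers in this file:

    # Build the CudaSession once, then reuse it for repeated inference runs.
    session = create_ort_session(config, enable_cuda_graph=False)
    # random_inputs() provides a feed dict for the graph inputs.
    outputs = session.infer(config.random_inputs())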


class OrtGroupQueryAttention:
"""A wrapper of ORT GroupQueryAttention to test relevance and performance."""

def __init__(self, config: GroupQueryAttentionConfig):
cuda_provider_options = CudaSession.get_cuda_provider_options(
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
)
onnx_model_str = create_group_query_attention_onnx_model(config)
self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
self.gpu_binding_manager = GpuBindingManager(
ort_session=self.ort_session,
device=config.device,
stream=torch.cuda.current_stream().cuda_stream,
max_cuda_graphs=2,
)
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
self.gpu_binding = self.gpu_binding_manager.get_binding(
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
)
self.session = create_ort_session(config)

self.feed_dict = config.random_inputs()

if ENABLE_DEBUG and not config.is_packed_qkv:
@@ -726,28 +746,14 @@ def __init__(self, config: GroupQueryAttentionConfig):
print("seqlens_k (BSNH, GQA)", self.feed_dict["seqlens_k"])

def infer(self):
return self.gpu_binding.infer(self.feed_dict)
return self.session.infer(self.feed_dict)


class OrtSparseAttention:
"""A wrapper of ORT SparseAttention to test relevance and performance."""

def __init__(self, config: SparseAttentionConfig):
cuda_provider_options = CudaSession.get_cuda_provider_options(
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
)
onnx_model_str = create_sparse_attention_onnx_model(config)
self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
self.gpu_binding_manager = GpuBindingManager(
ort_session=self.ort_session,
device=config.device,
stream=torch.cuda.current_stream().cuda_stream,
max_cuda_graphs=2,
)
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
self.gpu_binding = self.gpu_binding_manager.get_binding(
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
)
self.session = create_ort_session(config)
self.feed_dict = config.random_inputs()

if ENABLE_DEBUG and not config.is_packed_qkv:
@@ -770,7 +776,7 @@ def __init__(self, config: SparseAttentionConfig):
print("key_total_sequence_lengths", self.feed_dict["key_total_sequence_lengths"])

def infer(self):
return self.gpu_binding.infer(self.feed_dict)
return self.session.infer(self.feed_dict)


def get_provider_support_info(provider: str, use_kv_cache: bool):
@@ -817,6 +823,7 @@ def get_simple_test_case(provider: str, has_past_kv: bool):
local_blocks=2,
vert_stride=2,
softmax_scale=0.0,
provider=provider,
device=device,
dtype=dtype,
is_packed_qkv=packed_qkv,
@@ -834,7 +841,9 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
batch_sizes = [1, 2, 3]
sequence_lengths = [1, 64, 127, 128, 192, 256]
heads = [4, 8, 16]
head_sizes = [128, 256]

# SparseAttention CUDA kernel only supports head size 128
head_sizes = [128] if provider == "CUDAExecutionProvider" else [128, 256]

if comprehensive:
for batch_size in batch_sizes:
@@ -865,6 +874,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
local_blocks=2,
vert_stride=2,
softmax_scale=1.8 / (128**0.5),
provider=provider,
device=device,
dtype=dtype,
is_packed_qkv=packed_qkv,
@@ -903,6 +913,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
local_blocks=2,
vert_stride=2,
softmax_scale=1.8 / (128**0.5),
provider=provider,
device=device,
dtype=dtype,
is_packed_qkv=packed_qkv,
@@ -963,6 +974,10 @@ def run_one_relevance_test(self, config: SparseAttentionConfig):
obj = TorchGroupQueryAttention(gqa_config)
expected_out = obj.infer()
else:
if config.dtype == torch.bfloat16:
# Skip test since create_group_query_attention_onnx_model does not support bfloat16 right now.
return

# Run GQA by ORT (supports packed QKV, rotary and very long sequence, but no mask, so dense only).
gqa_config: GroupQueryAttentionConfig = config.get_comparable_ort_gqa_config(use_local=False)
obj = OrtGroupQueryAttention(gqa_config)
@@ -1070,6 +1085,8 @@ def run_relevance_no_past_128k(self, sm: int, device):
local_blocks=2048, # use dense to compare with GQA
vert_stride=8,
softmax_scale=None,
provider="CUDAExecutionProvider",
dtype=torch.float16,
device=device,
is_packed_qkv=packed_qkv,
)
@@ -1096,6 +1113,8 @@ def run_relevance_past_128k(self, sm: int, device):
local_blocks=2048, # use dense to compare with GQA
vert_stride=8,
softmax_scale=None,
provider="CUDAExecutionProvider",
dtype=torch.float16,
device=device,
is_packed_qkv=packed_qkv,
)
