Multi-thread tests for MultiHeadAttention
tianleiwu committed Jul 21, 2024
1 parent f9fc075 commit 6f399a6
Showing 2 changed files with 338 additions and 11 deletions.
49 changes: 46 additions & 3 deletions onnxruntime/test/python/transformers/benchmark_mha.py
@@ -156,6 +156,49 @@ def shape_dict(self, input_format=None):
            )
        return shapes

+    def symbolic_shape_dict(self, input_format=None):
+        input_format = input_format or self.input_format
+        if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
+            # cross attention does not have past state
+            return {
+                "query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
+                "key": ("batch_size", self.num_heads, "sequence_length", self.head_size),
+                "value": ("batch_size", self.num_heads, "sequence_length", self.head_size),
+                "output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
+            }
+
+        if self.use_kv_cache:
+            shapes = {
+                "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
+                "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
+                "output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
+                "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
+                "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
+            }
+        else:
+            shapes = {
+                "output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
+            }
+
+        if input_format == InputFormats.QKV_BSN3H:
+            shapes.update({"query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size)})
+        elif input_format == InputFormats.Q_KV_BSNH_BSN2H:
+            shapes.update(
+                {
+                    "query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
+                    "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size),
+                }
+            )
+        else:  # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH
+            shapes.update(
+                {
+                    "query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
+                    "key": ("batch_size", "sequence_length", self.num_heads * self.head_size),
+                    "value": ("batch_size", "sequence_length", self.num_heads * self.head_size),
+                }
+            )
+        return shapes

    def random_inputs(self, seed: int = 123):
        device = self.device
        dtype = self.dtype
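Each string in these tuples becomes a symbolic (dynamic) dimension when the ONNX graph is built. A minimal sketch of that mapping using onnx.helper directly; the hidden_size value and variable names are illustrative, not taken from the benchmark:

from onnx import TensorProto, helper

# Symbolic entries become dim_param strings, so the graph accepts any
# batch_size / sequence_length at run time; integers stay fixed dimensions.
hidden_size = 8 * 64  # stands in for num_heads * head_size
query_shape = ("batch_size", "sequence_length", hidden_size)

query_value_info = helper.make_tensor_value_info("query", TensorProto.FLOAT, list(query_shape))
print(query_value_info)  # shows dim_param "batch_size"/"sequence_length" and dim_value 512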
@@ -215,7 +258,7 @@ def random_inputs(self, seed: int = 123):

    def get_input_output_names(self):
        if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
-            return ["query", "key"], ["output"]
+            return ["query", "key", "value"], ["output"]

        if self.input_format == InputFormats.QKV_BSN3H:
            inputs, outputs = ["query"], ["output"]
@@ -235,7 +278,7 @@ def fill_optional_mha_inputs(input_names):
    return input_names[:-2] + [""] * (len(inputs) - len(input_names)) + input_names[-2:]


-def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig):
+def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False):
    input_names, output_names = config.get_input_output_names()

    float_type = TensorProto.FLOAT16 if config.dtype == torch.float16 else TensorProto.FLOAT
@@ -252,7 +295,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig):
        ),
    ]

-    shape_dict = config.shape_dict()
+    shape_dict = config.symbolic_shape_dict() if use_symbolic_shape else config.shape_dict()
    inputs = [
        helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name]))
        for input_name in input_names
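With use_symbolic_shape=True the exported graph has dynamic batch_size and sequence_length, so a single InferenceSession can be shared by several threads feeding differently shaped inputs, which is the pattern a multi-thread test exercises. A hedged sketch of that usage pattern, with a placeholder model path and illustrative feed names and shapes, not meant to reproduce the commit's own test harness:

import concurrent.futures

import numpy as np
import onnxruntime as ort

# Placeholder path; in the benchmark this would be a model produced by
# create_multi_head_attention_onnx_model(config, use_symbolic_shape=True).
session = ort.InferenceSession("mha_symbolic.onnx", providers=["CPUExecutionProvider"])


def run_once(batch_size, sequence_length, hidden_size=512):
    # batch_size and sequence_length are symbolic in the graph, so each thread
    # may pick different concrete values against the same session.
    feeds = {
        name: np.random.rand(batch_size, sequence_length, hidden_size).astype(np.float32)
        for name in ("query", "key", "value")
    }
    return session.run(None, feeds)[0].shape


# A single InferenceSession is safe to run concurrently from multiple threads.
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(lambda args: run_once(*args), [(1, 32), (2, 64), (4, 128)]))
print(results)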