[CUDA] Fix MHA mask (#21655)
### Description
Fix a mask-type check that I introduced in a recent commit, and add tests.
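
Root cause: in the old condition `(nullptr == key_padding_mask || AttentionMaskType::MASK_1D_KEY_SEQ_LEN)`, the right-hand side of the `||` is the enum constant itself rather than a comparison against `parameters.mask_type`. The constant is non-zero, so the expression is always true and the fused path could be selected for mask types it does not support. A minimal Python sketch of the logic error and the fix (the real check is the C++ change below; the enum values here are illustrative):

```python
from enum import IntEnum


class AttentionMaskType(IntEnum):  # values illustrative, mirroring the C++ enum
    MASK_NONE = 0
    MASK_1D_KEY_SEQ_LEN = 1
    MASK_2D_KEY_PADDING = 2


def buggy_check(mask_type: AttentionMaskType, key_padding_mask) -> bool:
    # Old logic: mask_type is never consulted; the enum constant itself is
    # used as a boolean, and since it is non-zero the `or` is always true.
    return key_padding_mask is None or bool(AttentionMaskType.MASK_1D_KEY_SEQ_LEN)


def fixed_check(mask_type: AttentionMaskType) -> bool:
    # New logic: compare the actual mask type against the supported values.
    return mask_type in (AttentionMaskType.MASK_NONE, AttentionMaskType.MASK_1D_KEY_SEQ_LEN)


# A 2D padding mask is wrongly accepted by the old check and rejected by the fix.
assert buggy_check(AttentionMaskType.MASK_2D_KEY_PADDING, key_padding_mask=object())
assert not fixed_check(AttentionMaskType.MASK_2D_KEY_PADDING)
```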
tianleiwu authored and prathikr committed Aug 12, 2024
1 parent b67af53 commit ac38c44
Showing 3 changed files with 233 additions and 74 deletions.
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -182,6 +182,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif

bool is_mask_none_or_1d_k_len = parameters.mask_type == AttentionMaskType::MASK_NONE ||
parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
bool use_fused_cross_attention =
!use_flash_attention &&
!disable_fused_cross_attention_ &&
@@ -213,7 +215,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == relative_position_bias &&
(parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
nullptr == past_key && nullptr == present_key &&
-(nullptr == key_padding_mask || AttentionMaskType::MASK_1D_KEY_SEQ_LEN) &&
+is_mask_none_or_1d_k_len &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length && // self attention only for fused runner
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
99 changes: 93 additions & 6 deletions onnxruntime/test/python/transformers/benchmark_mha.py
@@ -71,6 +71,13 @@ class SdpaKernel(IntEnum):
TRT_CAUSAL_ATTENTION = 128


# Since we support attention bias, we only need to support up to a 2D mask.
class AttentionMaskFormat(IntEnum):
Mask_None = 0 # No attention mask.
Mask_1D_Key_SeqLen = 1 # Shape (batch_size), actual sequence lengths (excluding padding on the right side).
Mask_2D_Key_PaddingMask = 2 # Shape (batch_size, total_sequence_length), key padding mask.

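For illustration (not part of the diff): with `batch_size=2`, `total_sequence_length=4`, and actual key lengths `[2, 4]`, the two non-trivial formats look like this under right-side padding:

```python
import torch

batch_size, total_sequence_length = 2, 4
lengths = torch.tensor([2, 4], dtype=torch.int32)  # actual key lengths, excluding padding

# Mask_1D_Key_SeqLen: shape (batch_size,), one valid length per batch entry.
mask_1d = lengths.clone()  # tensor([2, 4], dtype=torch.int32)

# Mask_2D_Key_PaddingMask: shape (batch_size, total_sequence_length),
# 1 for valid key positions, 0 for right-side padding.
positions = torch.arange(total_sequence_length).unsqueeze(0)  # shape (1, 4)
mask_2d = (positions < lengths.unsqueeze(1)).to(torch.int32)
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
```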

class MultiHeadAttentionConfig:
def __init__(
self,
@@ -93,6 +100,7 @@ def __init__(
input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH,
verbose: bool = False,
has_bias: bool = False,
mask_format: int = AttentionMaskFormat.Mask_None,
):
self.operator = "MultiHeadAttention"
self.batch_size = batch_size
@@ -144,6 +152,19 @@ def __init__(
self.verbose = verbose
self.has_bias = has_bias

assert mask_format in [
AttentionMaskFormat.Mask_None,
AttentionMaskFormat.Mask_1D_Key_SeqLen,
AttentionMaskFormat.Mask_2D_Key_PaddingMask,
]
self.mask_format = mask_format

# mask_index_q and mask_index_kv will be updated in random_inputs() if mask_format is not Mask_None.
self.mask_index_kv = torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length
self.mask_index_q = (
torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.total_sequence_length
)

def __repr__(self):
return (
f"MultiHeadAttentionConfig(batch_size={self.batch_size}, sequence_length={self.sequence_length}, "
@@ -154,7 +175,7 @@ def __repr__(self):
f"share_past_present_buffer={self.share_past_present_buffer}, "
f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, "
f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, "
f"has_bias={self.has_bias}"
f"has_bias={self.has_bias}, mask_format={self.mask_format}"
)

def shape_dict(self, input_format=None):
@@ -207,6 +228,13 @@ def shape_dict(self, input_format=None):
if self.has_bias:
shapes["bias"] = (3 * self.num_heads * self.head_size,)

if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen:
shapes["mask"] = (self.batch_size,)
elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask:
shapes["mask"] = (self.batch_size, self.total_sequence_length)
else:
assert self.mask_format == AttentionMaskFormat.Mask_None

return shapes

def symbolic_shape_dict(self, input_format=None):
@@ -259,8 +287,35 @@ def symbolic_shape_dict(self, input_format=None):
if self.has_bias:
shapes["bias"] = (3 * self.num_heads * self.head_size,)

if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen:
shapes["mask"] = (self.batch_size,)
elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask:
shapes["mask"] = (self.batch_size, "total_sequence_length")
else:
assert self.mask_format == AttentionMaskFormat.Mask_None

return shapes

def right_side_padding_masks(self):
q_mask = torch.ones(self.batch_size, 1, self.sequence_length, 1, dtype=torch.bool, device=self.device)
k_mask = torch.ones(self.batch_size, 1, self.total_sequence_length, 1, dtype=torch.bool, device=self.device)
mask = torch.ones(
self.batch_size,
self.num_heads,
self.sequence_length,
self.total_sequence_length,
dtype=torch.bool,
device=self.device,
)

if self.mask_format != AttentionMaskFormat.Mask_None:
for i, (m, n) in enumerate(zip(self.mask_index_q, self.mask_index_kv)):
q_mask[i, :, m:, :] = False
k_mask[i, :, n:, :] = False
mask[i, :, m:, :] = False
mask[i, :, :, n:] = False
return q_mask, k_mask, mask
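
These boolean masks broadcast against attention weights of shape (batch_size, num_heads, sequence_length, total_sequence_length), with `True` meaning attend and `False` meaning masked out. A hedged sketch (not part of the diff) of how such a 4D mask can drive a reference attention computation in PyTorch:

```python
import torch
import torch.nn.functional as F

# Assumed toy shapes: batch 2, 4 heads, 8 query and 8 key positions, head size 32.
B, N, Sq, Skv, H = 2, 4, 8, 8, 32
q, k, v = (torch.randn(B, N, S, H) for S in (Sq, Skv, Skv))

mask = torch.ones(B, N, Sq, Skv, dtype=torch.bool)
mask[..., Skv // 2 :] = False  # e.g. the right half of the keys is padding

# With a boolean attn_mask, True keeps a position and False sets its score
# to -inf before the softmax.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # (B, N, Sq, H)
```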

def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False):
device = self.device
dtype = self.dtype
@@ -325,22 +380,51 @@ def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False):
if self.has_bias:
feeds["bias"] = torch.concat([bias_q, bias_k, bias_v], dim=0).reshape(shape_dict["bias"]).contiguous()

# Generate padding mask
if self.mask_format != AttentionMaskFormat.Mask_None:
self.mask_index_kv = torch.randint(
1, self.total_sequence_length + 1, (self.batch_size,), dtype=torch.int32, device=self.device
)
if self.past_sequence_length > 0:
self.mask_index_q = (
torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length
)
else: # prompt case
self.mask_index_q = self.mask_index_kv.clone()

mask = None
if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen:
mask = self.mask_index_kv.clone()
elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask:
k_mask = torch.ones(self.batch_size, 1, self.total_sequence_length, 1, dtype=torch.bool, device=self.device)
for i, n in enumerate(self.mask_index_kv):
k_mask[i, :, n:, :] = False
mask = k_mask.reshape(self.batch_size, self.total_sequence_length)
else:
assert self.mask_format == AttentionMaskFormat.Mask_None

if mask is not None:
feeds = {**feeds, "mask": mask.to(dtype=torch.int32)} # mask is int32 (not bool) for MultiHeadAttention op.

return feeds

def get_input_output_names(self):
if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
return ["query", "key", "value"], ["output"]

if self.input_format == InputFormats.QKV_BSN3H:
inputs, outputs = ["query", "key", "value"], ["output"]
elif self.input_format == InputFormats.QKV_BSN3H:
inputs, outputs = ["query"], ["output"]
elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H:
inputs, outputs = ["query", "key"], ["output"]
else:
inputs, outputs = ["query", "key", "value"], ["output"]

if self.has_bias:
assert self.input_format != InputFormats.Q_KV_BSNH_BSN2H
inputs = [*inputs, "bias"]

if self.mask_format != AttentionMaskFormat.Mask_None:
inputs = [*inputs, "mask"]

if self.has_past_input:
inputs = [*inputs, "past_key", "past_value"]

@@ -351,7 +435,7 @@ def get_input_output_names(self):


def fill_optional_mha_inputs(input_names):
inputs = ["query", "key", "value", "bias", "key_padding_mask", "relative_position_bias", "past_key", "past_value"]
inputs = ["query", "key", "value", "bias", "mask", "relative_position_bias", "past_key", "past_value"]

# Replace optional inputs that are not in input_names with an empty string
inputs_with_optional = [input if input in input_names else "" for input in inputs]
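
For example, with `input_names = ["query", "key", "value", "mask"]` the comprehension yields `["query", "key", "value", "", "mask", "", "", ""]`, keeping every optional input at its fixed positional slot in the ONNX node (an empty string stands for an omitted input).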
@@ -376,13 +460,16 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use
num_heads=config.num_heads,
unidirectional=int(config.causal),
scale=config.softmax_scale,
mask_filter_value=float("-inf"),
domain="com.microsoft",
),
]

shape_dict = config.symbolic_shape_dict() if use_symbolic_shape else config.shape_dict()
inputs = [
-helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name]))
+helper.make_tensor_value_info(
+    input_name, TensorProto.INT32 if input_name == "mask" else float_type, list(shape_dict[input_name])
+)
for input_name in input_names
if input_name
]
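
Putting the pieces together, a hedged usage sketch: constructor arguments beyond `input_format`, `has_bias`, and `mask_format` are inferred from the attribute assignments and `__repr__` above, since the full signature is elided in this diff.

```python
import torch

# Assumes benchmark_mha.py is importable; names follow the fragments shown above.
from benchmark_mha import (
    AttentionMaskFormat,
    InputFormats,
    MultiHeadAttentionConfig,
    create_multi_head_attention_onnx_model,
)

config = MultiHeadAttentionConfig(
    batch_size=2,
    sequence_length=8,
    num_heads=4,
    head_size=32,
    device="cuda",
    dtype=torch.float16,
    input_format=InputFormats.Q_K_V_BSNH_BSNH_BSNH,
    mask_format=AttentionMaskFormat.Mask_2D_Key_PaddingMask,
)

feeds = config.random_inputs(seed=0)
assert feeds["mask"].dtype == torch.int32  # the MHA op takes an int32 mask, not bool

model = create_multi_head_attention_onnx_model(config, use_symbolic_shape=False)
```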