diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
index 88732f2aeb2ed..603236b421630 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -650,7 +650,7 @@ void InvokeAddBiasTranspose(
     if (format != 1 && format != 2 && format != 3) {
       ORT_THROW("format must be 1, 2 or 3 for rotary attention");
     }
-    if (rotary_embedding != 64 && rotary_embedding != 128) {
-      ORT_THROW("rotary_embedding must be 64 or 128 for rotary attention");
+    if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) {
+      ORT_THROW("rotary_embedding must be 32, 64 or 128 for rotary attention");
     }
     if (v_head_size != -1 && qk_head_size != v_head_size) {
diff --git a/onnxruntime/test/python/transformers/test_parity_neox_attention.py b/onnxruntime/test/python/transformers/test_parity_neox_attention.py
index fbd3b2b3a0171..a98bb623beaea 100644
--- a/onnxruntime/test/python/transformers/test_parity_neox_attention.py
+++ b/onnxruntime/test/python/transformers/test_parity_neox_attention.py
@@ -29,6 +29,7 @@ def create_neox_attention_graph(
     qkv_weight,
     qkv_bias,
     num_heads,
+    rotary_embedding,
 ):
     nodes = [
         helper.make_node(
@@ -43,7 +44,7 @@ def create_neox_attention_graph(
             num_heads=num_heads,
             unidirectional=1,
             do_rotary=1,
-            rotary_embedding = 64,
+            rotary_embedding=rotary_embedding,
             domain="com.microsoft",
         ),
     ]
@@ -175,13 +176,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
 
 
 class GPTNeoXAttention(nn.Module):
-    def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
+    def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0, rotary_ndims=64):
         super().__init__()
         self.do_rotary = True
         self.num_attention_heads = num_head
         self.hidden_size = hidden_size
         self.head_size = self.hidden_size // self.num_attention_heads
-        self.rotary_ndims = 64
+        self.rotary_ndims = rotary_ndims
         max_positions = 2048
         self.register_buffer(
             "bias",
@@ -198,6 +199,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
         # self.query_key_value.bias.data.copy_(torch.tensor(np.zeros((3 * hidden_size))))
 
         if past_seq_len > 0:
+            assert self.rotary_ndims == self.head_size
             self.onnx_graph = create_neox_decoder_masked_self_attention_graph(
                 batch_size,
                 seq_len,
@@ -221,6 +223,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
                 .transpose(0, 1),
                 self.query_key_value.bias.reshape(self.num_attention_heads, 3, -1).transpose(0, 1).reshape(-1),
                 self.num_attention_heads,
+                self.rotary_ndims,
             )
 
     @classmethod
@@ -423,21 +426,22 @@ def test_gpt_neox_attention(self):
         for batch_size in [1, 2, 4, 8]:
             for seq_len in [32, 128, 512, 1024, 2048]:
                 for num_head in [12]:
-                    for hidden_size in [768, 960]:
-                        attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size)
+                    for rotary_ndims in [32, 64]:
+                        for hidden_size in [768, 960]:
+                            attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size, 0, rotary_ndims)
 
-                        hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
-                            torch.float32
-                        )
-
-                        torch_output = attn.torch_forward(hidden_states)
-                        ort_output = attn.onnx_forward(hidden_states)
-                        if ort_output is not None:
-                            assert torch.allclose(torch_output, ort_output, atol=1e-3)
-                            print(
-                                f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}"
+                            hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
+                                torch.float32
                             )
+                            torch_output = attn.torch_forward(hidden_states)
+                            ort_output = attn.onnx_forward(hidden_states)
+                            if ort_output is not None:
+                                assert torch.allclose(torch_output, ort_output, atol=1e-3)
+                                print(
+                                    f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}, {rotary_ndims}"
+                                )
 
     def test_gpt_neox_decoder_masked_self_attention(self):
         for batch_size in [1, 2, 4, 8]:
             for past_seq_len in [1, 4, 32, 128, 512, 1024]: