add dim=32 and refine test
wangyems committed Dec 15, 2023
1 parent d3bc5e6 commit 2580583
Showing 2 changed files with 20 additions and 16 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -650,7 +650,7 @@ void InvokeAddBiasTranspose(
if (format != 1 && format != 2 && format != 3) {
ORT_THROW("format must be 1, 2 or 3 for rotary attention");
}
-if (rotary_embedding != 64 && rotary_embedding != 128) {
+if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) {
ORT_THROW("rotary_embedding must be 64 or 128 for rotary attention");
}
if (v_head_size != -1 && qk_head_size != v_head_size) {
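For context, rotary_embedding above is the number of channels per head that the rotary position embedding is applied to (GPT-NeoX-style partial rotary); the remaining head channels pass through unchanged, which is what the relaxed check now also allows for 32. A minimal PyTorch sketch of that partial rotary application, consistent with apply_rotary_pos_emb in the test file below; the helper names and the base=10000.0 default are illustrative assumptions, not the kernel's internals:

import torch

def rotate_half(x):
    # GPT-NeoX pairing: concatenate (-second_half, first_half) along the last dim.
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def partial_rotary(x, rotary_dim, base=10000.0):
    # x: (batch, num_heads, seq_len, head_size); rotate only the first rotary_dim channels.
    seq_len = x.shape[-2]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)  # (seq_len, rotary_dim/2)
    emb = torch.cat((angles, angles), dim=-1)                                   # (seq_len, rotary_dim)
    cos, sin = emb.cos(), emb.sin()
    x_rot = x_rot * cos + rotate_half(x_rot) * sin
    return torch.cat((x_rot, x_pass), dim=-1)

q = torch.randn(1, 12, 8, 64)                   # (batch, heads, seq, head_size)
print(partial_rotary(q, rotary_dim=32).shape)   # torch.Size([1, 12, 8, 64])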
34 changes: 19 additions & 15 deletions onnxruntime/test/python/transformers/test_parity_neox_attention.py
@@ -29,6 +29,7 @@ def create_neox_attention_graph(
qkv_weight,
qkv_bias,
num_heads,
+rotary_embedding,
):
nodes = [
helper.make_node(
@@ -43,7 +44,7 @@
num_heads=num_heads,
unidirectional=1,
do_rotary=1,
-rotary_embedding = 64,
+rotary_embedding=rotary_embedding,
domain="com.microsoft",
),
]
@@ -175,13 +176,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):


class GPTNeoXAttention(nn.Module):
-def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
+def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0, rotary_ndims=64):
super().__init__()
self.do_rotary = True
self.num_attention_heads = num_head
self.hidden_size = hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
-self.rotary_ndims = 64
+self.rotary_ndims = rotary_ndims
max_positions = 2048
self.register_buffer(
"bias",
@@ -198,6 +199,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
# self.query_key_value.bias.data.copy_(torch.tensor(np.zeros((3 * hidden_size))))

if past_seq_len > 0:
+assert self.rotary_ndims == self.head_size
self.onnx_graph = create_neox_decoder_masked_self_attention_graph(
batch_size,
seq_len,
@@ -221,6 +223,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
.transpose(0, 1),
self.query_key_value.bias.reshape(self.num_attention_heads, 3, -1).transpose(0, 1).reshape(-1),
self.num_attention_heads,
+self.rotary_ndims,
)

@classmethod
@@ -423,21 +426,22 @@ def test_gpt_neox_attention(self):
for batch_size in [1, 2, 4, 8]:
for seq_len in [32, 128, 512, 1024, 2048]:
for num_head in [12]:
-for hidden_size in [768, 960]:
-    attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size)
+for rotary_ndims in [32, 64]:
+    for hidden_size in [768, 960]:
+        attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size, 0, rotary_ndims)

-hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
-    torch.float32
-)
-
-torch_output = attn.torch_forward(hidden_states)
-ort_output = attn.onnx_forward(hidden_states)
-if ort_output is not None:
-    assert torch.allclose(torch_output, ort_output, atol=1e-3)
-    print(
-        f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}"
-    )
+hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
+    torch.float32
+)
+
+torch_output = attn.torch_forward(hidden_states)
+ort_output = attn.onnx_forward(hidden_states)
+if ort_output is not None:
+    assert torch.allclose(torch_output, ort_output, atol=1e-3)
+    print(
+        f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}, {rotary_ndims}"
+    )

def test_gpt_neox_decoder_masked_self_attention(self):
for batch_size in [1, 2, 4, 8]:
for past_seq_len in [1, 4, 32, 128, 512, 1024]:
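As a usage sketch, one iteration of the refined loop above can be reproduced standalone, assuming the test module is importable (for example when run from onnxruntime/test/python/transformers) and an ONNX Runtime build that provides the com.microsoft Attention contrib op; the parity check is skipped when onnx_forward returns None, exactly as in the test:

import torch
from test_parity_neox_attention import GPTNeoXAttention

# 12 heads with hidden_size=768 gives head_size=64; rotary is applied to 32 of those 64 channels.
attn = GPTNeoXAttention(batch_size=2, seq_len=128, num_head=12, hidden_size=768,
                        past_seq_len=0, rotary_ndims=32)
hidden_states = torch.normal(mean=0.5, std=0.1, size=(2, 128, 768)).to(torch.float32)

torch_out = attn.torch_forward(hidden_states)
ort_out = attn.onnx_forward(hidden_states)
if ort_out is not None:  # comparison is skipped when onnx_forward returns None
    assert torch.allclose(torch_out, ort_out, atol=1e-3)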
