Disable sliding window in GQA (#282)
### Description

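This PR temporarily disables the sliding window attribute (`local_window_size`)
in the GroupQueryAttention (GQA) op by commenting it out of the `make_node` call
in `make_group_query_attention`.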
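
As a rough illustration of the effect, here is a minimal sketch of building the node attributes with the window attribute gated behind a flag instead of always being passed. The helper name `gqa_node_attrs`, the `use_sliding_window` flag, and the example values are hypothetical and are not part of `builder.py`; the actual change simply comments the attribute out.

```python
def gqa_node_attrs(num_heads, kv_num_heads, window_size=-1, use_sliding_window=False):
    """Build keyword attributes for a GroupQueryAttention node.

    Hypothetical helper for illustration only; not the builder's real code.
    """
    attrs = {"num_heads": num_heads, "kv_num_heads": kv_num_heads}
    # Emit local_window_size only when sliding window is explicitly enabled.
    # Omitting the attribute leaves the op to use its own default.
    if use_sliding_window and window_size > 0:
        attrs["local_window_size"] = window_size
    return attrs


# Example values are assumptions chosen for the demo.
disabled = gqa_node_attrs(32, 8, window_size=4096)
enabled = gqa_node_attrs(32, 8, window_size=4096, use_sliding_window=True)
print(disabled)
print(enabled)
```

With the flag off, the resulting attribute set matches what this commit produces: `num_heads` and `kv_num_heads` are present and `local_window_size` is absent.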

### Motivation and Context

This unblocks some CI pipelines and allows models with GQA to run
successfully on more machines and environments.
kunal-vaishnavi authored Apr 22, 2024
1 parent cf18f30 commit eee160f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/python/py/models/builder.py
@@ -882,7 +882,7 @@ def make_group_query_attention(self, name, **kwargs):
         outputs = [output, kwargs.get("present_k", ""), kwargs.get("present_v", "")]
         self.make_node(
             "GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft",
-            num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, local_window_size=self.window_size,
+            num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, # local_window_size=self.window_size, # Disable sliding window attribute temporarily
             do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"],
         )
         self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads])