From eee160f5c002e2d576d7648173682fffa47f42ee Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Mon, 22 Apr 2024 14:15:29 -0700 Subject: [PATCH] Disable sliding window in GQA (#282) ### Description This PR disables the sliding window attribute in the GroupQueryAttention (GQA) op. ### Motivation and Context This unblocks some CI pipelines and allows models with GQA to run successfully on more machines and environments. --- src/python/py/models/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 1d4e46757..c91cc6ce3 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -882,7 +882,7 @@ def make_group_query_attention(self, name, **kwargs): outputs = [output, kwargs.get("present_k", ""), kwargs.get("present_v", "")] self.make_node( "GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft", - num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, local_window_size=self.window_size, + num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, # local_window_size=self.window_size, # Disable sliding window attribute temporarily do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"], ) self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads])