From e222963307c179e888443ec066dc1b3e7b580034 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Wed, 30 Oct 2024 17:35:25 -0700 Subject: [PATCH] Add logit softcapping to GQA (#876) ### Description This PR adds the `softcap` attribute to the `GroupQueryAttention` op. ### Motivation and Context This PR helps resolve the `NaN` output issue with Gemma-2 raised in [this issue](https://github.com/microsoft/onnxruntime-genai/issues/692). --- src/python/py/models/builder.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 7951de8e2..6c54bd429 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -147,6 +147,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): } # LayerNorm-specific variables + epsilon = config.rms_norm_eps if hasattr(config, "rms_norm_eps") else 1e-06 self.layernorm_attrs = { "simple": True, # Use SimplifiedLayerNorm/SkipSimplifiedLayerNorm vs. LayerNorm/SkipLayerNorm "first_layernorm": True, # 1st LayerNorm = LayerNorm, then SkipLayerNorm for all subsequent LayerNorms @@ -156,6 +157,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): "output_0": "", # Output 0 for LayerNorm and SkipLayerNorm "output_3": "", # Output 3 for SkipLayerNorm "add_offset": 0, # Offset value for LayerNorm weight + "epsilon": epsilon, # Epsilon value to avoid `sqrt(0)` in LayerNorm } # MatMul-specific variables @@ -212,6 +214,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): } # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.) + softcap = config.attn_logit_softcapping if hasattr(config, "attn_logit_softcapping") else 0.0 # default is 0.0 in GroupQueryAttention kernel + # Block-sparse attention-specific variables sparse_block_size = config.blocksparse_block_size if hasattr(config, "blocksparse_block_size") else 0 kernel_block_size = config.blocksparse_triton_kernel_block_size if hasattr(config, "blocksparse_triton_kernel_block_size") else 0 @@ -224,6 +228,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): "v_path": "", # V path to attention "op_type": "MultiHeadAttention", # Attention op to use "scale": 1 / np.sqrt(self.head_size), # Scale value after calculating Q x K' in attention + "softcap": softcap, # Softcap value to prevent values from exploding in attention "use_rotemb_in_attn": False, # Use rotary embeddings within attention (instead of a separate RotaryEmbedding op) "use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V) "block_sparse": { # Block-sparse attention-specific variables @@ -969,7 +974,7 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location): name = f"/model/layers.{layer_id}/{location}_layernorm/{'Skip' if skip else ''}LayerNorm" op_type = f"{'Skip' if skip else ''}{'Simplified' if simple else ''}LayerNormalization" - kwargs = {"epsilon": 9.999999747378752e-06} + kwargs = {"epsilon": self.layernorm_attrs["epsilon"]} if not skip: kwargs.update({"axis": -1, "stash_type": 1}) @@ -1381,7 +1386,7 @@ def make_group_query_attention(self, name, **kwargs): self.make_node( "GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft", num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, scale=self.attention_attrs["scale"], # local_window_size=self.window_size, # Disable sliding window attribute temporarily - do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"], + softcap=self.attention_attrs["softcap"], do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"], ) self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads])