From 9ca2473b6c210d0a3c3b57ba3cf9ffb284d4a31f Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 9 Nov 2023 17:50:37 +0100
Subject: [PATCH] Fix Falcon ONNX export with alibi (#1524)

fix falcon export with alibi
---
 optimum/exporters/onnx/model_patcher.py | 37 +++++++++++++++++++++++--
 tests/exporters/exporters_utils.py      | 12 +++++++-
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index b9f0df29ea..0dec0db53c 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -15,12 +15,12 @@
 import dataclasses
 import functools
 import inspect
+import math
 import types
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 from packaging import version
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.falcon.modeling_falcon import build_alibi_tensor
 from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
 from transformers.utils import is_torch_available
 
@@ -310,7 +310,9 @@ def falcon_model_forward_without_kv_reformatting(
             attention_mask = attention_mask.to(hidden_states.device)
 
         if self.use_alibi:
-            alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+            # NOTE: we use a patched build_alibi_tensor (see falcon_build_alibi_tensor_patched below).
+            alibi = falcon_build_alibi_tensor_patched(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+            # alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
         else:
             alibi = None
         if position_ids is None:
@@ -431,6 +433,37 @@ def __init__(
         self.original_make_causal = AttentionMaskConverter._make_causal_mask
 
 
+def falcon_build_alibi_tensor_patched(
+    attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
+) -> torch.Tensor:
+    batch_size, seq_length = attention_mask.shape
+    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+    base = torch.tensor(
+        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != num_heads:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+        )
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
+    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+    # => the query_length dimension will then be broadcast correctly
+    # This is more or less identical to T5's relative position bias:
+    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+    # NOTE: the .bfloat16() cast of the original implementation is removed here, as the PyTorch ONNX export would otherwise cast to complex128, resulting in an onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph error.
+    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+    alibi = slopes[..., None] * arange_tensor
+    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
+
+
 class FalconModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index e0640b7657..f17184a1b7 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -72,7 +72,17 @@
         ],
         "mohitsha/tiny-random-testing-bert2gpt2": ["text2text-generation", "text2text-generation-with-past"],
     },
-    "falcon": "fxmarty/really-tiny-falcon-testing",
+    "falcon": {
+        "fxmarty/really-tiny-falcon-testing": [
+            "feature-extraction",
+            "feature-extraction-with-past",
+            "question-answering",
+            "text-generation",
+            "text-generation-with-past",
+            "token-classification",
+        ],
+        "fxmarty/tiny-testing-falcon-alibi": ["text-generation", "text-generation-with-past"],
+    },
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
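For reference, a minimal sketch of what the patched alibi builder produces once this patch is applied (the attention mask and head count below are illustrative; the import works because the patch adds the function at module level):

    import torch

    from optimum.exporters.onnx.model_patcher import falcon_build_alibi_tensor_patched

    # Illustrative inputs: batch_size=1, seq_length=4, last position padded out.
    attention_mask = torch.tensor([[1, 1, 1, 0]])
    alibi = falcon_build_alibi_tensor_patched(attention_mask, num_heads=4, dtype=torch.float32)

    # One bias row per (batch, head) pair, broadcast over the query dimension:
    print(alibi.shape)  # torch.Size([4, 1, 4]), i.e. (batch_size * num_heads, 1, key_length)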
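The fxmarty/tiny-testing-falcon-alibi checkpoint added to the test table exercises this code path end to end. A minimal export sketch, assuming the main_export API of optimum.exporters.onnx (the output directory name is arbitrary):

    from optimum.exporters.onnx import main_export

    # Export an alibi-enabled Falcon checkpoint; the patched forward (and thus
    # falcon_build_alibi_tensor_patched) is applied while tracing the model.
    main_export(
        "fxmarty/tiny-testing-falcon-alibi",
        output="falcon_alibi_onnx",
        task="text-generation-with-past",
    )

The equivalent CLI invocation should be: optimum-cli export onnx --model fxmarty/tiny-testing-falcon-alibi --task text-generation-with-past falcon_alibi_onnx/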