Fix Falcon ONNX export with alibi (#1524)
fix falcon export with alibi
fxmarty authored Nov 9, 2023
1 parent 41347fc commit 9ca2473
Showing 2 changed files with 46 additions and 3 deletions.
37 changes: 35 additions & 2 deletions optimum/exporters/onnx/model_patcher.py
@@ -15,12 +15,12 @@
import dataclasses
import functools
import inspect
import math
import types
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

from packaging import version
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon.modeling_falcon import build_alibi_tensor
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
from transformers.utils import is_torch_available

@@ -310,7 +310,9 @@ def falcon_model_forward_without_kv_reformatting(
    attention_mask = attention_mask.to(hidden_states.device)

    if self.use_alibi:
        alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
        # NOTE: we use a patched build_alibi_tensor.
        alibi = falcon_build_alibi_tensor_patched(attention_mask, self.num_heads, dtype=hidden_states.dtype)
        # alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
    else:
        alibi = None
        if position_ids is None:
@@ -431,6 +433,37 @@ def __init__(
        self.original_make_causal = AttentionMaskConverter._make_causal_mask


def falcon_build_alibi_tensor_patched(
    attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
) -> torch.Tensor:
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that is applied to the query-key product of attention
    # => therefore alibi has to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcast correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    # NOTE: the .bfloat16() cast is removed here, as the PyTorch ONNX export would otherwise cast to complex128,
    # resulting in an onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph error.
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)


class FalconModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
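For context, FalconModelPatcher swaps in the patched forward (and with it the patched alibi builder) only for the duration of the export. Below is a minimal sketch of exercising the code path this commit fixes; it assumes optimum's main_export API and the fxmarty/tiny-testing-falcon-alibi checkpoint added to the tests further down.

# Sketch only: assumes an optimum install that includes this fix, plus onnx and onnxruntime.
from optimum.exporters.onnx import main_export

# A Falcon checkpoint configured with alibi goes through falcon_build_alibi_tensor_patched
# during tracing, avoiding the bfloat16 cast that previously produced an invalid ONNX graph.
main_export(
    model_name_or_path="fxmarty/tiny-testing-falcon-alibi",
    output="falcon_alibi_onnx",
    task="text-generation-with-past",
)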
12 changes: 11 additions & 1 deletion tests/exporters/exporters_utils.py
@@ -72,7 +72,17 @@
        ],
        "mohitsha/tiny-random-testing-bert2gpt2": ["text2text-generation", "text2text-generation-with-past"],
    },
    "falcon": "fxmarty/really-tiny-falcon-testing",
    "falcon": {
        "fxmarty/really-tiny-falcon-testing": [
            "feature-extraction",
            "feature-extraction-with-past",
            "question-answering",
            "text-generation",
            "text-generation-with-past",
            "token-classification",
        ],
        "fxmarty/tiny-testing-falcon-alibi": ["text-generation", "text-generation-with-past"],
    },
    "flaubert": "hf-internal-testing/tiny-random-flaubert",
    "gpt2": "hf-internal-testing/tiny-random-gpt2",
    "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
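The new fxmarty/tiny-testing-falcon-alibi entry exercises the alibi path end to end in the exporter tests. As a rough usage check (not part of this commit), the exported model can also be loaded and run through optimum.onnxruntime, assuming that package is installed:

# Sketch only: loads the tiny alibi Falcon model, exports it to ONNX on the fly, and generates a few tokens.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model_id = "fxmarty/tiny-testing-falcon-alibi"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)  # export goes through the patched exporter

inputs = tokenizer("Hello, alibi!", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.batch_decode(generated))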
