Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable gpt2, falcon has core dump error in PagedAttention.single_quer… #979

Merged
merged 4 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions optimum/exporters/ipex/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def update(
value_states: torch.Tensor,
layer_idx: int,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
input_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Expand All @@ -175,7 +174,7 @@ def update(
A tuple containing the updated key and value states.
"""

batch_size = position_ids.shape[0]
batch_size = input_lens.shape[-1]
if self.get_seq_length() == 0:
# prefill
self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens)
Expand Down
16 changes: 10 additions & 6 deletions optimum/exporters/ipex/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.

from transformers.models.bert.modeling_bert import BertIntermediate
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block
from transformers.models.falcon.modeling_falcon import FalconModel, FalconDecoderLayer
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaModel,
Expand All @@ -27,13 +27,14 @@

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_gpt2_block_forward,
_ipex_rms_layer_norm_forward,
_IPEXFalconDecoderLayer,
_IPEXGPT2Attention,
_IPEXIntermediate,
_IPEXLlamaDecoderLayer,
_llama_model_forward,
_falcon_model_forward,
_gpt2_model_forward,
)


Expand Down Expand Up @@ -90,7 +91,9 @@ def _patch_falcon_model(model):
2. Use IPEX Rope and paged cache
3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
"""
model.transformer._use_sdpa = False
num_key_value_heads = model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
setattr(model.config, "num_key_value_heads", num_key_value_heads)
convert_functions(model, FalconModel, "forward", _falcon_model_forward)
replace_customized_linear_with_linear(model)
convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
return model
Expand All @@ -102,9 +105,10 @@ def _patch_gpt2_model(model):
1. Disable SDPA so the attention mask will be compatible to ipex attention.
2. Use IAKV cache
"""
model.transformer._attn_implementation = "eager"
num_key_value_heads = model.config.num_attention_heads
setattr(model.config, "num_key_value_heads", num_key_value_heads)
convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
return model


Expand Down
Loading