diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py
index 72c79f4b6..3d01770e3 100755
--- a/optimum/exporters/ipex/cache_utils.py
+++ b/optimum/exporters/ipex/cache_utils.py
@@ -158,7 +158,6 @@ def update(
         value_states: torch.Tensor,
         layer_idx: int,
         attention_mask: torch.Tensor,
-        position_ids: torch.Tensor,
         input_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -175,7 +174,7 @@ def update(
             A tuple containing the updated key and value states.
         """

-        batch_size = position_ids.shape[0]
+        batch_size = input_lens.shape[-1]
         if self.get_seq_length() == 0:
             # prefill
             self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens)
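The cache change above derives the batch size from `input_lens` instead of `position_ids`. A minimal sketch (illustrative only, not part of the patch) of why that works, assuming `input_lens` is built from the attention mask exactly as the patched model forwards build it:

```python
import torch

# A padded batch of two sequences (illustrative values).
attention_mask = torch.tensor(
    [
        [1, 1, 1, 0, 0],  # real length 3, right-padded
        [1, 1, 1, 1, 1],  # real length 5
    ]
)

# Same construction used in the patched *_model_forward functions below.
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

print(input_lens)            # tensor([3, 5], dtype=torch.int32)
print(input_lens.shape[-1])  # 2 -> one entry per sequence, i.e. the batch size
```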
diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
index 5a8009774..9f68074c7 100644
--- a/optimum/exporters/ipex/model_patcher.py
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -13,8 +13,8 @@
 #  limitations under the License.

 from transformers.models.bert.modeling_bert import BertIntermediate
-from transformers.models.falcon.modeling_falcon import FalconDecoderLayer
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block
+from transformers.models.falcon.modeling_falcon import FalconModel, FalconDecoderLayer
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -27,13 +27,14 @@

 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _gpt2_block_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
     _IPEXGPT2Attention,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _llama_model_forward,
+    _falcon_model_forward,
+    _gpt2_model_forward,
 )


@@ -90,7 +91,9 @@ def _patch_falcon_model(model):
     2. Use IPEX Rope and paged cache
     3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
     """
-    model.transformer._use_sdpa = False
+    num_key_value_heads = model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
     convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
     return model
@@ -102,9 +105,10 @@ def _patch_gpt2_model(model):
     1. Disable SDPA so the attention mask will be compatible to ipex attention.
     2. Use IAKV cache
     """
-    model.transformer._attn_implementation = "eager"
+    num_key_value_heads = model.config.num_attention_heads
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
     return model

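For context, the `convert_functions` / `convert_class` calls above rebind the `forward` of matching submodules with the paged-attention implementations. The standalone sketch below illustrates that style of forward-patching; it is an assumption-level illustration only (the helper name `patch_forward` and its body are hypothetical and not the optimum-intel implementation):

```python
import types

import torch
from torch import nn


def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    """Rebind `forward` on every submodule that is an instance of `target_cls`."""
    for module in model.modules():
        if isinstance(module, target_cls):
            # Bind the replacement function to this specific module instance.
            module.forward = types.MethodType(new_forward, module)


def _noisy_linear_forward(self, x):
    # Replacement forward: same computation as nn.Linear, plus a marker print.
    print(f"patched forward on {self.__class__.__name__}")
    return nn.functional.linear(x, self.weight, self.bias)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
patch_forward(model, nn.Linear, _noisy_linear_forward)
out = model(torch.randn(1, 4))  # prints twice, once per patched Linear
```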
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 5555b0b80..b8e92be63 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -20,7 +20,7 @@
 from torch import nn
 from transformers.cache_utils import Cache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions
 from transformers.models.gpt2.modeling_gpt2 import GPT2Block

 from optimum.intel.utils.import_utils import is_ipex_version
@@ -182,7 +182,10 @@ def _llama_model_forward(
         position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    setattr(past_key_values, "input_lens", input_lens)
+
     for idx, decoder_layer in enumerate(self.layers):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -195,7 +198,6 @@ def _llama_model_forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             position_embeddings=position_embeddings,
-            input_lens=input_lens,
         )

         hidden_states = layer_outputs[0]
@@ -213,30 +215,268 @@ def _llama_model_forward(
         all_hidden_states += (hidden_states,)

     next_cache = next_decoder_cache if use_cache else None
-    if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

     if hidden_states.shape[0] != batch_size * seq_length:
         (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
         hidden_states = hidden_states_copy
+    hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

     return BaseModelOutputWithPast(
-        last_hidden_state=hidden_states.view(batch_size, -1, hidden_states.shape[-1]),
+        last_hidden_state=hidden_states,
         past_key_values=next_cache,
         hidden_states=all_hidden_states,
         attentions=all_self_attns,
     )


-def _gpt2_block_forward(self, hidden_states, *args, **kwargs):
-    attention_mask = kwargs.get("attention_mask", None)
-    if attention_mask is not None:
-        bsz, seq_len, _ = hidden_states.size()
-        layer_past = kwargs.get("layer_past", None)
-        past_len = layer_past[0].size(-2) if layer_past is not None else 0
-        attention_mask = (1 - attention_mask / torch.finfo(attention_mask.dtype).min).squeeze(1, 2)
-        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (bsz, seq_len), hidden_states, past_len)
-        kwargs["attention_mask"] = attention_mask
+# Adapted from https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/models/falcon/modeling_falcon.py#L945
+def _falcon_model_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    head_mask: Optional[torch.LongTensor] = None,
+    inputs_embeds: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.word_embeddings(input_ids)
+
+    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+    batch_size, seq_length, _ = inputs_embeds.shape
+
+    if cache_position is None:
+        cache_position = torch.arange(
+            past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    # Prepare head mask if needed
+    # 1.0 in head_mask indicate we keep the head
+    # attention_probs has shape batch_size x num_heads x N x N
+    # head_mask has shape n_layer x batch x num_heads x N x N
+    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+    hidden_states = inputs_embeds
+
+    # create position embeddings to be shared across the decoder layers
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+    if past_key_values_length == 0:
+        # first token, remove the padding from hidden_states, varlen do not accept attention mask
+        hidden_states_copy = hidden_states
+        index = attention_mask.view(-1) != 0
+        hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
+        cos = position_embeddings[0]
+        sin = position_embeddings[1]
+        cos = (cos.reshape(-1, cos.shape[-1]))[index]
+        sin = (sin.reshape(-1, sin.shape[-1]))[index]
+        position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
+    else:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    setattr(past_key_values, "input_lens", input_lens)
+
+    next_decoder_cache = None
+    all_self_attentions = () if output_attentions else None
+    all_hidden_states = () if output_hidden_states else None
+
+    for i, block in enumerate(self.h):
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)

-    return GPT2Block.forward(self, hidden_states, *args, **kwargs)
+        outputs = block(
+            hidden_states,
+            layer_past=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask[i],
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            alibi=None,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
+
+        hidden_states = outputs[0]
+        if use_cache is True:
+            next_decoder_cache = outputs[1]
+
+        if output_attentions:
+            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+    # Add last hidden state
+    hidden_states = self.ln_f(hidden_states)
+
+    if output_hidden_states:
+        all_hidden_states = all_hidden_states + (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+
+    if hidden_states.shape[0] != batch_size * seq_length:
+        (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
+        hidden_states = hidden_states_copy
+
+    hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+    if not return_dict:
+        return tuple(
+            v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
+        )
+
+    return BaseModelOutputWithPastAndCrossAttentions(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attentions,
+    )
+
+
+def _gpt2_model_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    token_type_ids: Optional[torch.LongTensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    head_mask: Optional[torch.FloatTensor] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+    elif input_ids is not None:
+        self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+    elif inputs_embeds is not None:
+        input_shape = inputs_embeds.size()[:-1]
+    else:
+        raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+    if token_type_ids is not None:
+        token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+    past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+    if position_ids is None:
+        position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0)
+
+    if inputs_embeds is None:
+        inputs_embeds = self.wte(input_ids)
+    batch_size, seq_length, _ = inputs_embeds.shape
+    position_embeddings = self.wpe(position_ids)
+    hidden_states = inputs_embeds + position_embeddings
+
+    encoder_attention_mask = None
+    head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+    if token_type_ids is not None:
+        token_type_embeds = self.wte(token_type_ids)
+        hidden_states = hidden_states + token_type_embeds
+
+    hidden_states = self.drop(hidden_states)
+
+    if past_length == 0:
+        # first token, remove the padding from hidden_states, varlen do not accept attention mask
+        hidden_states_copy = hidden_states
+        index = attention_mask.view(-1) != 0
+        hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
+    else:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    setattr(past_key_values, "input_lens", input_lens)
+
+    presents = None
+    all_self_attentions = () if output_attentions else None
+    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+    all_hidden_states = () if output_hidden_states else None
+    for i, block in enumerate(self.h):
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        outputs = block(
+            hidden_states,
+            layer_past=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask[i],
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+
+        hidden_states = outputs[0]
+        if use_cache is True:
+            presents = outputs[1]
+
+        if output_attentions:
+            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+            if self.config.add_cross_attention:
+                all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+    hidden_states = self.ln_f(hidden_states)
+    if hidden_states.shape[0] != batch_size * seq_length:
+        (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
+        hidden_states = hidden_states_copy
+
+    hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+    # Add last hidden state
+    if output_hidden_states:
+        all_hidden_states = all_hidden_states + (hidden_states,)
+
+    if not return_dict:
+        return tuple(
+            v
+            for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+            if v is not None
+        )
+
+    return BaseModelOutputWithPastAndCrossAttentions(
+        last_hidden_state=hidden_states,
+        past_key_values=presents,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attentions,
+        cross_attentions=all_cross_attentions,
+    )
+

 class _IPEXAttention(nn.Module):
@@ -256,8 +496,8 @@ def qkv_gemm(self, hidden_states):
     def rope(self, *args, **kwargs):
         raise NotImplementedError("Need to implement in specific model class")

-    def postprocess_attention_output(self, attn_output, bsz, seq_len):
-        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size)
+    def postprocess_attention_output(self, attn_output):
+        attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         return attn_output

     def forward(
@@ -268,25 +508,20 @@ def forward(
         past_key_value: Optional[IPEXPagedCache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        input_lens: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if past_key_value is None and kwargs.get("layer_past", None) is not None:
             past_key_value = kwargs.pop("layer_past", None)
-        bsz, seq_len = position_ids.size()
+        input_lens = getattr(past_key_value, "input_lens", None)
         past_len = 0
         if past_key_value is not None:
             past_len = past_key_value.get_seq_length()
-        qkv_out = self.qkv_gemm(hidden_states)
-        if isinstance(qkv_out, tuple) and len(qkv_out) == 3:
-            query, key, value = qkv_out[0], qkv_out[1], qkv_out[2]
-            query, key = self.rope(query, key, **kwargs)
-        else:
-            query, key, value = self.rope(qkv_out, **kwargs)
+        query, key, value = self.qkv_gemm(hidden_states)
+        query, key = self.rope(query, key, **kwargs)

         if past_key_value is not None:
             key_cache, value_cache = past_key_value.update(
-                key, value, self.layer_idx, attention_mask, position_ids, input_lens
+                key, value, self.layer_idx, attention_mask, input_lens
             )

         attn_output = torch.empty_like(query)
@@ -325,7 +560,7 @@ def forward(
                 None,
             )

-        attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
+        attn_output = self.postprocess_attention_output(attn_output)

         if not output_attentions:
             attn_weights = None
@@ -373,36 +608,49 @@ def rope(self, query, key, **kwargs):


 class _IPEXFalconAttention(_IPEXAttention):
+    def __init__(self, module, config):
+        self.num_key_value_heads = config.num_key_value_heads
+        super().__init__(module, config)
+        self.q_slice = self.head_dim * config.num_kv_heads
+        self.k_slice = self.q_slice + self.head_dim
+        self.v_slice = self.k_slice + self.head_dim
+
     def qkv_gemm(self, hidden_states):
-        return self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+        qkv_out = self.query_key_value(hidden_states)
+        if self.new_decoder_architecture:
+            qkv_out = qkv_out.view(qkv_out.shape[0], -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
+            query = qkv_out[:, :, :-2, :].flatten(1, 2)
+            key = qkv_out[:, :, [-2], :].flatten(1, 2)
+            value = qkv_out[:, :, [-1], :].flatten(1, 2)
+        else:
+            query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
+            key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
+            value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
+        return query, key, value

-    def rope(self, fused_qkv, **kwargs):
+    def rope(self, query, key, **kwargs):
         position_embeddings = kwargs.pop("position_embeddings", None)
-        (query, key, value) = self._split_heads(fused_qkv)
         rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
-        return query, key, value
+        return query, key


 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
+        self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)

-    def _split_heads_ipex(self, tensor, num_heads, attn_head_size):
-        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        return tensor.view(new_shape)  # (batch, seq_length, head, head_features)
-
     def qkv_gemm(self, hidden_states):
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-        query = self._split_heads_ipex(query, self.num_heads, self.head_dim)
-        key = self._split_heads_ipex(key, self.num_heads, self.head_dim)
-        value = self._split_heads_ipex(value, self.num_heads, self.head_dim)
+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
+        query = query.view(-1, self.num_heads, self.head_dim)
+        key = key.view(-1, self.num_heads, self.head_dim)
+        value = value.view(-1, self.num_heads, self.head_dim)
         return query, key, value

     def rope(self, query, key, *args, **kwargs):
         return query, key

-    def postprocess_attention_output(self, attn_output, bsz, seq_len):
-        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.embed_dim)
+    def postprocess_attention_output(self, attn_output):
+        attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
         return attn_output
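The patched model forwards above drop padded positions before the first decoder layer (the varlen attention path takes a flat token dimension plus per-sequence lengths instead of an attention mask) and scatter the results back into a padded tensor at the end. A standalone sketch of that pack/unpack round trip, with illustrative tensor names only:

```python
import torch

batch_size, seq_len, hidden = 2, 5, 8
hidden_states = torch.randn(batch_size, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

# Pack: keep only real tokens, shape (num_tokens, hidden), as varlen attention expects.
index = attention_mask.view(-1) != 0
packed = hidden_states.view(-1, hidden)[index]
assert packed.shape[0] == int(attention_mask.sum())

# ... run the decoder layers on `packed` ...

# Unpack: scatter the processed tokens back into a padded (batch, seq, hidden) tensor.
unpacked = torch.zeros_like(hidden_states.view(-1, hidden))
unpacked[index] = packed
unpacked = unpacked.view(batch_size, seq_len, hidden)
```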
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 270a9b32d..d34b4f3c4 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -230,7 +230,6 @@ def _from_pretrained(
         }

         task = cls.export_feature
-        config.torch_dtype = torch_dtype
         model = TasksManager.get_model_from_task(
             task,
             model_id,
@@ -240,6 +239,7 @@ def _from_pretrained(
             _commit_hash=commit_hash,
             **model_kwargs,
         )
+        config = model.config
         return cls(model, config=config, export=True, **kwargs)

     def _save_pretrained(self, save_directory: Union[str, Path]):
@@ -305,11 +305,7 @@ def can_generate(self):
         return isinstance(self, GenerationMixin)

     def _call_model(self, *args, **kwargs):
-        try:
-            with torch.autocast(self.device.type, self.dtype), torch.no_grad():
-                out = self.model(*args, **kwargs)
-        except RuntimeError:
-            out = self.model(*args, **kwargs)
+        out = self.model(*args, **kwargs)
         return out

     def _init_warmup(self):
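As a closing illustration of the attention refactor in modeling_utils.py, the sketch below mirrors the grouped (new_decoder_architecture) slicing used by `_IPEXFalconAttention.qkv_gemm` to split a fused QKV projection into per-head query/key/value on a flattened token dimension. Shapes and variable names here are illustrative, not taken from the library:

```python
import torch

num_tokens, num_heads, num_kv_heads, head_dim = 7, 8, 2, 16
groups = num_heads // num_kv_heads  # query heads sharing one kv head

# Fused projection output: one row per (unpadded) token.
fused = torch.randn(num_tokens, num_kv_heads * (groups + 2) * head_dim)

qkv = fused.view(num_tokens, num_kv_heads, groups + 2, head_dim)
query = qkv[:, :, :-2, :].flatten(1, 2)   # (num_tokens, num_heads, head_dim)
key = qkv[:, :, [-2], :].flatten(1, 2)    # (num_tokens, num_kv_heads, head_dim)
value = qkv[:, :, [-1], :].flatten(1, 2)  # (num_tokens, num_kv_heads, head_dim)

assert query.shape == (num_tokens, num_heads, head_dim)
assert key.shape == value.shape == (num_tokens, num_kv_heads, head_dim)
```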