diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index a522424a..2a07da6c 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -638,7 +638,6 @@ def LlamaModel_fast_forward(
             past_key_values_length,
             sliding_window = getattr(self.config, "sliding_window", None),
         )
-        print(attention_mask)
     pass
 
     hidden_states = inputs_embeds
@@ -683,11 +682,26 @@ def LlamaModel_fast_forward(
     # Gemma2 has alternating SWA and global attn
-    if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
+    if IS_GEMMA2:
         if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
             self.SWA_mask = True
             self.GA_mask = False
-        else:
+        elif attention_mask is not None:
+            self.SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window = self.config.sliding_window,
+            )
+            self.GA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window = None,
+            )
+        elif not hasattr(self, "SWA_mask"):
            n = self.max_seq_length # self.config.max_position_embeddings
            # masked_fill is making stuff slower!
            # self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
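
For context, here is a minimal, standalone sketch of the two masks the new `elif attention_mask is not None:` branch precomputes. It is not part of the diff: it assumes `_prepare_4d_causal_attention_mask_for_sdpa` is importable from `transformers.modeling_attn_mask_utils` (the same helper the patched code calls), and the batch shape, hidden size, and window of 4 are toy values for illustration; Gemma-2 itself uses `config.sliding_window` for its SWA layers and no window for its global-attention layers.

```python
# Sketch only -- toy shapes, not the values used inside LlamaModel_fast_forward.
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch_size, seq_length, hidden_size = 2, 8, 16
past_key_values_length = 0

# Right-padded batch: the second sequence has three padding tokens.
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 0, 0, 0],
])
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)

# Sliding-window variant: each token sees roughly the previous
# `sliding_window` positions only (toy window of 4 here).
swa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask,
    (batch_size, seq_length),
    inputs_embeds,
    past_key_values_length,
    sliding_window = 4,
)

# Global variant: an ordinary causal mask with no window restriction.
ga_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask,
    (batch_size, seq_length),
    inputs_embeds,
    past_key_values_length,
    sliding_window = None,
)

# Both are 4D additive masks of shape (batch_size, 1, seq_length, seq_length):
# masked positions hold the dtype's most negative value, visible ones hold 0.
print(swa_mask.shape, ga_mask.shape)
```

The reason for caching both variants on `self` is that Gemma-2 alternates sliding-window and global attention layers, so each decoder layer can pick `self.SWA_mask` or `self.GA_mask` as appropriate instead of rebuilding a 4D mask per layer.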