
Commit 394156d

Update llama.py
danielhanchen committed Aug 11, 2024
1 parent eab5319 commit 394156d
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions unsloth/models/llama.py

@@ -638,7 +638,6 @@ def LlamaModel_fast_forward(
             past_key_values_length,
             sliding_window = getattr(self.config, "sliding_window", None),
         )
-        print(attention_mask)
     pass
 
     hidden_states = inputs_embeds
@@ -683,11 +682,26 @@ def LlamaModel_fast_forward(
 
 
     # Gemma2 has alternating SWA and global attn
-    if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
+    if IS_GEMMA2:
         if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
             self.SWA_mask = True
             self.GA_mask = False
-        else:
+        elif attention_mask is not None:
+            self.SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window = self.config.sliding_window,
+            )
+            self.GA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window = None,
+            )
+        elif not hasattr(self, "SWA_mask"):
             n = self.max_seq_length # self.config.max_position_embeddings
             # masked_fill is making stuff slower!
             # self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
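For context on the new elif attention_mask is not None branch: when an explicit attention mask is supplied and the flash-attention softcapping path cannot handle it, both masks are rebuilt each forward call via transformers' _prepare_4d_causal_attention_mask_for_sdpa, SWA_mask with sliding_window = self.config.sliding_window for Gemma2's local layers and GA_mask with sliding_window = None for its global layers. The snippet below is a minimal illustrative sketch of that sliding-window vs. global distinction only, not the commit's code; the helper toy_causal_mask and the toy sizes are hypothetical.

# Illustrative sketch (hypothetical helper, not part of the commit): the real path
# uses transformers' _prepare_4d_causal_attention_mask_for_sdpa instead.
import torch

def toy_causal_mask(seq_length, sliding_window=None):
    # True marks key positions a given query position may attend to.
    idx = torch.arange(seq_length)
    mask = idx[None, :] <= idx[:, None]  # causal constraint: key <= query
    if sliding_window is not None:
        # Additionally keep only the last `sliding_window` keys per query.
        mask = mask & (idx[None, :] > (idx[:, None] - sliding_window))
    return mask

ga_mask  = toy_causal_mask(6)                    # analogue of GA_mask (global layers)
swa_mask = toy_causal_mask(6, sliding_window=3)  # analogue of SWA_mask (local layers)
print(ga_mask.int())
print(swa_mask.int())

Running the sketch shows the sliding-window mask restricting each query to its last three key positions, while the global mask keeps the full causal triangle.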
