Gemma faster inference (#987)
* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* update token retrieval logic (#952)

* Fix DPO (#947)

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* update hf token retrieval logic

---------

Co-authored-by: Daniel Han <[email protected]>

* Update llama.py

* get_token

* Update README.md

* Update gemma2.py

* Update rms_layernorm.py

* synchronize

* Update gemma2.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* layernorm

* Update rms_layernorm.py

* Update gemma2.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* revert

* Gemma

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update gemma2.py

* Change UnslothTrainingArguments base class to SFTConfig (#979)

* Cohere

* Update trainer.py

* Cohere

* Cohere

* New models

* Update llama.py

* Update llama.py

* Update cohere.py

* Update llama.py

* Update cohere.py

* retry

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* _apply_lora_mlp

* Update _utils.py

* Gemma fixes

* Update llama.py

* Update flex_attention.py

---------

Co-authored-by: Hafedh <[email protected]>
Co-authored-by: Tuan Pham <[email protected]>
3 people authored Sep 3, 2024
1 parent c085a45 commit 480f2ef
Showing 5 changed files with 62 additions and 2 deletions.
6 changes: 5 additions & 1 deletion unsloth/kernels/__init__.py
@@ -33,7 +33,11 @@
)
from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora

from .flex_attention import HAS_FLEX_ATTENTION, slow_attention_softcapping
from .flex_attention import (
HAS_FLEX_ATTENTION,
slow_attention_softcapping,
slow_inference_attention_softcapping,
)

if HAS_FLEX_ATTENTION:
from .flex_attention import (
37 changes: 37 additions & 0 deletions unsloth/kernels/flex_attention.py
@@ -80,3 +80,40 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
return A
pass


torch_matmul = torch.matmul
torch_tanh = torch.tanh
torch_nn_functional_softmax = torch.nn.functional.softmax
def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
head_dim = self.head_dim
n_kv_heads = self.num_key_value_heads
n_groups = self.num_key_value_groups

# Grouped query attention
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
K = K.reshape(bsz, n_heads, q_len, head_dim)
V = V.reshape(bsz, n_heads, q_len, head_dim)

# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
# We default to using the config file itself
# s = self.config.hidden_size // self.config.num_attention_heads
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping

Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch_matmul(Q, K.transpose(2, 3))

# Logit softcapping
A /= t; torch_tanh(A, out = A); A *= t;
A += causal_mask[:q_len, :q_len]
# Much slower in torch compile!
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch_matmul(A, V)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
pass
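
For context on the new slow_inference_attention_softcapping kernel above, the short sketch below isolates the logit softcapping it applies to the attention scores (divide by the cap, tanh, multiply back), which keeps every score inside (-softcap, softcap). It is a minimal, self-contained sketch with toy shapes; softcapped_attention_scores is an illustrative helper, not code from this commit.

import torch

def softcapped_attention_scores(Q, K, softcap, scale):
    # scores = softcap * tanh((Q @ K^T) * scale / softcap)
    A = torch.matmul(Q, K.transpose(-2, -1)) * scale
    return softcap * torch.tanh(A / softcap)

# Toy shapes: batch = 1, heads = 2, seq = 4, head_dim = 8
Q = torch.randn(1, 2, 4, 8)
K = torch.randn(1, 2, 4, 8)
scores = softcapped_attention_scores(Q, K, softcap = 50.0, scale = 8 ** -0.5)
print(scores.shape)                     # torch.Size([1, 2, 4, 4])
print(bool(scores.abs().max() < 50.0))  # True: the cap bounds every score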
4 changes: 4 additions & 0 deletions unsloth/models/_utils.py
@@ -39,6 +39,8 @@
"create_boolean_mask",
"torch_amp_custom_fwd",
"torch_amp_custom_bwd",
"accelerate_old_send_to_device",
"accelerate_new_send_to_device",
]

import torch
@@ -287,6 +289,7 @@ def patch_mistral_nemo_config(config):
import accelerate.utils.operations
if hasattr(accelerate.utils.operations, "send_to_device") and \
accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device":
accelerate_old_send_to_device = accelerate.utils.operations.send_to_device
from accelerate.utils.operations import *
send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device)
send_to_device = re.sub(
@@ -296,6 +299,7 @@ def patch_mistral_nemo_config(config):
).replace("def send_to_device", "def _fixed_send_to_device")
exec(send_to_device)
# accelerate.utils.operations.send_to_device = _fixed_send_to_device
accelerate_new_send_to_device = _fixed_send_to_device
pass
pass
# =============================================
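The two new exports, accelerate_old_send_to_device and accelerate_new_send_to_device, let the generation path temporarily install the regex-patched send_to_device and put the original back afterwards (the llama.py hunk below does exactly that around generate). Below is a minimal sketch of that swap/restore pattern, assuming only that accelerate is installed; the wrapper itself is illustrative, not part of the commit.

import accelerate.utils.operations

def run_with_patched_send_to_device(fn, patched, original, *args, **kwargs):
    # Install the patched function, run fn, then always restore the original.
    accelerate.utils.operations.send_to_device = patched
    try:
        return fn(*args, **kwargs)
    finally:
        accelerate.utils.operations.send_to_device = original
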
6 changes: 5 additions & 1 deletion unsloth/models/gemma2.py
@@ -157,7 +157,10 @@ def Gemma2Attention_fast_forward(
A = A.reshape(bsz, q_len, n_heads*head_dim)
else:
mask = causal_mask if attention_mask is None else attention_mask
A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len)
fx = slow_inference_attention_softcapping \
if "_flag_for_generation" in kwargs else \
slow_attention_softcapping
A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len)
pass
A = self.apply_o(self, A)
return A, None, past_key_value
@@ -192,6 +195,7 @@ def Gemma2DecoderLayer_fast_forward(
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
_flag_for_generation=True,
)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
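The attention hunk above routes a call through the eager inference kernel only when a private _flag_for_generation marker appears in kwargs, and the decoder layer's inference path now passes _flag_for_generation=True; training calls omit the marker and keep the existing kernel. Below is a self-contained sketch of that presence-based dispatch, with stand-in functions rather than unsloth APIs.

def train_kernel(x):  return ("train", x)
def decode_kernel(x): return ("decode", x)

def attention(x, **kwargs):
    # Route on the presence of the marker kwarg, not its value,
    # mirroring the '"_flag_for_generation" in kwargs' check above.
    fx = decode_kernel if "_flag_for_generation" in kwargs else train_kernel
    return fx(x)

print(attention(1.0))                              # ('train', 1.0)
print(attention(1.0, _flag_for_generation = True)) # ('decode', 1.0)
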
11 changes: 11 additions & 0 deletions unsloth/models/llama.py
@@ -953,6 +953,8 @@ def _CausalLM_fast_forward(
if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
logits = logits.unsqueeze(0).unsqueeze(0)
elif num_logits_to_keep != 0:
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))
else:
logits = self.lm_head(hidden_states.to(lm_head.dtype))
pass
@@ -1368,8 +1370,14 @@ def _fast_generate(*args, **kwargs):
pass
internal_model._flag_for_generation = True

# Must patch accelerate for Xformers
import accelerate.utils.operations
accelerate.utils.operations.send_to_device = accelerate_new_send_to_device

# For newer HF
kwargs["cache_implementation"] = "dynamic"
# For num_logits_to_keep
kwargs["num_logits_to_keep"] = 1

# Remove token_type_ids
kwargs.pop("token_type_ids", None)
@@ -1402,6 +1410,9 @@ def _fast_generate(*args, **kwargs):
pass
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation

# Return accelerate back
accelerate.utils.operations.send_to_device = accelerate_old_send_to_device

return output
pass
return _fast_generate
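The num_logits_to_keep branch above only pushes the last kept positions through lm_head, and _fast_generate sets num_logits_to_keep = 1: during decoding only the final position's logits are sampled, so the result is unchanged while the projection for every earlier position is skipped. Below is a toy-sized sketch of why that is safe, using plain PyTorch rather than unsloth code.

import torch

bsz, seq_len, hidden, vocab = 2, 16, 64, 1000
hidden_states = torch.randn(bsz, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias = False)

full_logits = lm_head(hidden_states)             # (2, 16, 1000): every position
last_logits = lm_head(hidden_states[:, -1:, :])  # (2,  1, 1000): num_logits_to_keep = 1

# The final row agrees (up to float tolerance), so next-token selection is identical.
print(torch.allclose(full_logits[:, -1:, :], last_logits, atol = 1e-5))
print(bool((full_logits[:, -1, :].argmax(-1) == last_logits[:, -1, :].argmax(-1)).all()))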
