Granite support #1218
base: nightly
Conversation
Great work again! Just some comments :)
unsloth/models/granite.py
Outdated
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
sw = getattr(self.config, "sliding_window", None)
@Datta0 Is sliding window attention necessary for Granite?
Um, not necessary. Removing it.
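For reference, a minimal sketch of what the attention call reduces to once the sliding-window branch is dropped - this uses plain SDPA for illustration, not the actual unsloth kernel path, and the shapes are made up:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only; Granite uses full causal attention, so no window is needed.
bsz, n_heads, seq_len, head_dim = 1, 8, 16, 64
Q = torch.randn(bsz, n_heads, seq_len, head_dim)
K = torch.randn(bsz, n_heads, seq_len, head_dim)
V = torch.randn(bsz, n_heads, seq_len, head_dim)

# Without a sliding window there is nothing to configure beyond causal masking.
attn_output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
```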
if past_key_value is not None:
    kv_seq_len += past_key_value[0].shape[-2]

assert position_embeddings is not None
I remember you said we must pass in the position embeddings - did we calculate the cos and sine matrices in RoPE incorrectly?
Oh, this is just a validation. We calculate the sin and cos and pass them in from here.
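For anyone following along, a self-contained sketch of the pattern being described - the cos/sin tables are built once at the model level and each attention layer only validates and consumes them (the helper below is illustrative, not the unsloth implementation):

```python
import torch

def rope_cos_sin(seq_len, head_dim, base=10000.0, dtype=torch.float32):
    """Build the RoPE cos/sin tables once so every layer can reuse them."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)            # (seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)     # (seq_len, head_dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)

# Model forward: compute once, pass down to every decoder layer.
position_embeddings = rope_cos_sin(seq_len=8, head_dim=64)

# Attention forward: only validate that the tuple arrived, then unpack it.
assert position_embeddings is not None
cos, sin = position_embeddings
```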
pass


def GraniteDecoderLayer_fast_forward(
Can we inherit from LlamaDecoderLayer_fast_forward? [Actually scratch that - I forgot Granite has a residual multiplier]
I'm assuming it's because of position_embeddings.
Yep
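For context, a rough sketch of why the layer forward differs from Llama's - every residual add is scaled by config.residual_multiplier and the RoPE tuple arrives from the model level (names are illustrative, not the exact unsloth code):

```python
def granite_decoder_layer_forward(layer, hidden_states, position_embeddings, residual_multiplier):
    # Attention block: the residual connection is scaled, unlike the plain Llama path.
    residual = hidden_states
    hidden_states = layer.input_layernorm(hidden_states)
    hidden_states = layer.self_attn(hidden_states, position_embeddings=position_embeddings)[0]
    hidden_states = residual + hidden_states * residual_multiplier

    # MLP block: same extra scaling on the residual connection.
    residual = hidden_states
    hidden_states = layer.post_attention_layernorm(hidden_states)
    hidden_states = layer.mlp(hidden_states)
    hidden_states = residual + hidden_states * residual_multiplier
    return hidden_states
```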
unsloth/models/granite.py
Outdated
use_cache=use_cache,
padding_mask=padding_mask,
position_embeddings = position_embeddings,
_flag_for_generation=True,
I don't think flagging it for generation unconditionally is a good idea - we have to set this dynamically.
Oh, this is inspired by gemma2. Should we set it based on what we see in the config?
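One possible way to set it dynamically instead of hardcoding True - purely a sketch, the decode check and argument names are assumptions, not how unsloth necessarily detects generation:

```python
def decoder_layer_kwargs(hidden_states, past_key_value, use_cache, padding_mask, position_embeddings):
    """Build per-layer kwargs, enabling the generation flag only during decode."""
    kwargs = dict(
        use_cache=use_cache,
        padding_mask=padding_mask,
        position_embeddings=position_embeddings,
    )
    # Flag only when decoding one new token against an existing KV cache.
    if past_key_value is not None and hidden_states.shape[1] == 1:
        kwargs["_flag_for_generation"] = True
    return kwargs
```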
unsloth/models/granite.py
Outdated
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)

# Handle sliding windows
sliding_window = self.config.sliding_window if hasattr(self.config, "sliding_window") else self.config.max_position_embeddings
Is SWA necessary in Granite?
unsloth/models/granite.py
Outdated
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
position_embeddings = position_embeddings,
)
hidden_states = residual + hidden_states * self.config.residual_multiplier
Technically we could use addmm to fuse this entirely into 1 op.
I resorted to using torch.add since we don't have any matmul here. Thanks for the suggestion :)
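For reference, torch.add with alpha does fuse the scale and the add into one call (it computes input + alpha * other), assuming residual_multiplier is a plain float from the config:

```python
import torch

residual = torch.randn(2, 8, 64)
hidden_states = torch.randn(2, 8, 64)
residual_multiplier = 0.22  # example value; Granite stores this in config.residual_multiplier

# torch.add(input, other, alpha=c) computes input + c * other in a single call.
out = torch.add(residual, hidden_states, alpha=residual_multiplier)
assert torch.allclose(out, residual + hidden_states * residual_multiplier)
```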


@staticmethod
def post_patch(model):
We can ignore this, I think (if it's a copy from Llama) - it should be auto-inherited.
Correct me if I'm wrong, but wouldn't tied word embeddings mandate handling this separately?
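Rough sketch of the concern, using the standard HF attribute names (lm_head, model.embed_tokens) as assumptions rather than anything Granite-specific:

```python
def retie_lm_head_if_needed(model):
    """If the checkpoint ties the LM head to the input embeddings, any post-patching
    must keep both modules pointing at the same weight tensor."""
    if getattr(model.config, "tie_word_embeddings", False):
        model.lm_head.weight = model.model.embed_tokens.weight
    return model
```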
@@ -617,6 +617,7 @@ def LlamaModel_fast_forward(
IS_GEMMA = self.config.model_type.startswith("gemma")
IS_GEMMA2 = self.config.model_type.startswith("gemma2")
IS_COHERE = self.config.model_type.startswith("cohere")
IS_GRANITE = self.config.model_type.startswith("granite")
Fix up spacing to make all the equal signs spaced evenly :)
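i.e. something like this, taking the assignments from the hunk above:

```python
IS_GEMMA   = self.config.model_type.startswith("gemma")
IS_GEMMA2  = self.config.model_type.startswith("gemma2")
IS_COHERE  = self.config.model_type.startswith("cohere")
IS_GRANITE = self.config.model_type.startswith("granite")
```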
@@ -763,6 +766,12 @@ def LlamaModel_fast_forward(
pass
pass

if IS_GRANITE:
So this is an absolute must?
Yeah, IIRC Granite's forward calculates it here and passes it on, and without this it throws an error (I don't exactly remember which error, unfortunately).
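A rough sketch of the idea, with the attribute names treated as assumptions about the HF API rather than the exact unsloth code:

```python
def compute_position_embeddings(model, hidden_states, position_ids, is_granite):
    """For Granite, compute the RoPE (cos, sin) pair once at the model level and
    hand it to every decoder layer; skipping this leaves the layers with nothing
    to apply RoPE with, which is where the error comes from."""
    if not is_granite:
        return None
    # Recent transformers rotary embeddings return a (cos, sin) tuple from
    # rotary_emb(hidden_states, position_ids).
    return model.rotary_emb(hidden_states, position_ids)
```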
@@ -974,6 +986,9 @@ def _CausalLM_fast_forward(
loss = None
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
logit_scaling = getattr(self.config, "logit_scale", 0)
if self.config.model_type == "granite":
    # Granite uses `logits_scaling` as the config key and divides by the scale, unlike Cohere
    logit_scaling = 1 / getattr(self.config, "logits_scaling", 1)
Oh interesting - can you confirm it's not Cohere-style logit scaling? Thanks :)
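A small sketch of the difference between the two conventions, with the field names taken from the respective HF configs and a made-up example value:

```python
import torch

logits = torch.randn(2, 4, 32)
logits_scaling = 8.0  # Granite's config field is `logits_scaling`; Cohere's is `logit_scale`

# Cohere multiplies:  logits * config.logit_scale
# Granite divides:    logits / config.logits_scaling
# Reusing the multiplicative code path therefore means passing the reciprocal:
granite_logits = logits * (1.0 / logits_scaling)
assert torch.allclose(granite_logits, logits / logits_scaling)
```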