Granite support #1218

Open · Datta0 wants to merge 7 commits into base: nightly
Conversation

@Datta0 (Contributor) commented Oct 29, 2024

No description provided.

@Datta0 Datta0 marked this pull request as ready for review October 31, 2024 14:15
@danielhanchen (Contributor) left a comment:

Great work again! Just some comments :)

Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
sw = getattr(self.config, "sliding_window", None)
@danielhanchen (Contributor):
@Datta0 Is sliding window attention necessary for Granite?

@Datta0 (Contributor, Author):

Um, not necessary. Removing it.

if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

assert position_embeddings is not None
@danielhanchen (Contributor):
I remember you said we must pass in the position embeddings - did we calculate the cos and sine matrices in RoPE incorrectly?

@Datta0 (Contributor, Author):

Oh, this is just a validation. We calculate the sin and cos and pass them in from here.
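For context, here is a minimal sketch of what "calculating the sin and cos and passing them in" looks like, assuming standard RoPE with the default base of 10000; the function name and shapes are illustrative, not the PR's actual code:

import torch

def build_rope_cos_sin(seq_len, head_dim, base=10000.0, device="cpu", dtype=torch.float32):
    # Standard RoPE frequencies: one per pair of head dimensions.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)          # (seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, head_dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)

# Computed once at the model level, then threaded into every attention call,
# which is why the attention forward only needs to assert that it was passed.
cos, sin = build_rope_cos_sin(seq_len=128, head_dim=64)
position_embeddings = (cos, sin)
assert position_embeddings is not None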

pass


def GraniteDecoderLayer_fast_forward(
@danielhanchen (Contributor):

Can we inherit from LlamaDecoderLayer_fast_forward? [Actually scratch that - I forgot Granite has a residual multiplier]

I'm assuming it's because of position_embeddings

@Datta0 (Contributor, Author):

Yep
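For readers following along, a toy sketch of why this layer forward cannot reuse Llama's verbatim: every residual add in Granite is scaled by config.residual_multiplier. The modules and multiplier value below are placeholders, not the real implementation:

import torch
import torch.nn as nn

class ToyGraniteDecoderLayer(nn.Module):
    def __init__(self, hidden=64, residual_multiplier=0.22):
        super().__init__()
        self.residual_multiplier = residual_multiplier
        self.self_attn = nn.Linear(hidden, hidden)   # stand-in for attention
        self.mlp = nn.Linear(hidden, hidden)         # stand-in for the MLP

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.self_attn(hidden_states)
        hidden_states = residual + hidden_states * self.residual_multiplier

        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * self.residual_multiplier
        return hidden_states

x = torch.randn(2, 8, 64)
print(ToyGraniteDecoderLayer()(x).shape)   # torch.Size([2, 8, 64])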

use_cache=use_cache,
padding_mask=padding_mask,
position_embeddings = position_embeddings,
_flag_for_generation=True,
@danielhanchen (Contributor):

I don't think flagging it for generation is a good idea - we have to set this dynamically

@Datta0 (Contributor, Author):

Oh, this is inspired by gemma2. Should we set it to what we see in the config?
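As a purely illustrative sketch of "setting this dynamically" (the heuristic below is an assumption, not Unsloth's actual mechanism): derive the flag from the runtime state at call time instead of hard-coding it in the patched call.

import torch.nn as nn

def should_flag_for_generation(module: nn.Module, use_cache: bool) -> bool:
    # Assumption for illustration: "generation" roughly means eval mode with a KV cache.
    return (not module.training) and bool(use_cache)

layer = nn.Linear(4, 4).eval()
print(should_flag_for_generation(layer, use_cache=True))   # True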

Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)

# Handle sliding windows
sliding_window = self.config.sliding_window if hasattr(self.config, "sliding_window") else self.config.max_position_embeddings
@danielhanchen (Contributor):

Is SWA necessary in Granite?
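For reference, a small sketch of what the fallback on the quoted line does when the config has no sliding_window (Granite's case): the window collapses to max_position_embeddings, i.e. full attention. The config object below is a mock:

from types import SimpleNamespace

config = SimpleNamespace(max_position_embeddings=4096)  # no sliding_window attribute
sliding_window = config.sliding_window if hasattr(config, "sliding_window") else config.max_position_embeddings
print(sliding_window)  # 4096 -> effectively no sliding window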

do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
position_embeddings = position_embeddings,
)
hidden_states = residual + hidden_states * self.config.residual_multiplier
@danielhanchen (Contributor):

Technically we could use addmm to fuse this entirely into 1 op

@Datta0 (Contributor, Author):

I resorted to using torch.add because we don't have any matmul here. Thanks for the suggestion :)
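A minimal sketch of the fused version being discussed, assuming torch.add's alpha argument is what was meant by fusing the scale-and-add into one op (addmm does not apply since there is no matmul; the multiplier value is illustrative):

import torch

residual      = torch.randn(2, 16, 512)
hidden_states = torch.randn(2, 16, 512)
residual_multiplier = 0.22   # illustrative, not Granite's actual config value

# torch.add(input, other, alpha=s) computes input + s * other in one elementwise kernel.
fused   = torch.add(residual, hidden_states, alpha=residual_multiplier)
unfused = residual + hidden_states * residual_multiplier
assert torch.allclose(fused, unfused)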



@staticmethod
def post_patch(model):
@danielhanchen (Contributor):

I think we can ignore this (if it's a copy from Llama) - it should be auto-inherited (I think)

@Datta0 (Contributor, Author):

Correct me if I'm wrong, but wouldn't tied word embeddings mandate handling this separately?
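A hedged illustration of the tie_word_embeddings concern (toy sizes, not Granite's real dimensions): when tied, lm_head.weight is the same tensor as embed_tokens.weight, so a post_patch that replaces or re-casts one has to keep the other in sync.

import torch.nn as nn

vocab_size, hidden = 32000, 1024
embed_tokens = nn.Embedding(vocab_size, hidden)
lm_head = nn.Linear(hidden, vocab_size, bias=False)
lm_head.weight = embed_tokens.weight   # weight tying

# Same underlying storage: patching one silently changes the other.
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()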

@@ -617,6 +617,7 @@ def LlamaModel_fast_forward(
IS_GEMMA = self.config.model_type.startswith("gemma")
IS_GEMMA2 = self.config.model_type.startswith("gemma2")
IS_COHERE = self.config.model_type.startswith("cohere")
IS_GRANITE = self.config.model_type.startswith("granite")
@danielhanchen (Contributor):

Fix up spacing to make all the equal signs spaced evenly :)

@@ -763,6 +766,12 @@ def LlamaModel_fast_forward(
pass
pass


if IS_GRANITE:
@danielhanchen (Contributor):

So this is strictly necessary?

@Datta0 (Contributor, Author):

Yeah, IIRC Granite's forward calculates it here and passes it on, and without this it throws an error (I don't remember exactly which error, unfortunately).

@@ -974,6 +986,9 @@ def _CausalLM_fast_forward(
loss = None
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
logit_scaling = getattr(self.config, "logit_scale", 0)
if self.config.model_type == "granite":
# granite uses logits_scaling as the key and divides by the scale, unlike cohere
logit_scaling = 1 / getattr(self.config, "logits_scaling", 1)
@danielhanchen (Contributor):

Oh interesting - can you confirm it's not Cohere-type logit scaling? Thanks :)

@Datta0 (Contributor, Author):

  • Granite uses logits_scaling as the config key and divides by the scale, unlike Cohere.
  • Note that for Granite, logits_scaling is 16, while for Cohere, logit_scale is 0.125 (i.e. 1/8) in their respective configs.
  • granite config
  • cohere config
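To make the difference concrete, a small sketch using the values quoted above (Granite's logits_scaling = 16, Cohere's logit_scale = 0.125); the tensor shape is arbitrary:

import torch

logits = torch.randn(2, 8, 32000)

granite_logits_scaling = 16.0    # Granite divides by its scale
cohere_logit_scale     = 0.125   # Cohere multiplies by its scale

granite_logits = logits / granite_logits_scaling
cohere_logits  = logits * cohere_logit_scale

# Reusing the Cohere-style multiply path only needs the reciprocal, which is
# what logit_scaling = 1 / config.logits_scaling achieves in the diff above.
assert torch.allclose(granite_logits, logits * (1.0 / granite_logits_scaling))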
