[Bug Report] Gemma-2-2b-it output logit doesn't match with huggingface #693
Labels

complexity-high: Very complicated changes for people to address who are quite familiar with the code
implementation-inaccuracy: Any issues related to our implementation being off from the official version

Comments
TransformerLens centers the unembedding, which shifts every logit at a given position by a fixed amount (the shift can vary across tokens). Can you run the comparison again on log probs? Or try from_pretrained_no_processing? There are known accuracy issues, but I want to rule out trivial causes.
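A minimal sketch of the centering point, with toy tensors (the `W_U` centering here mirrors what TransformerLens's `center_unembed` option is understood to do, namely subtract the mean over the vocab dimension; the tensor names are illustrative, not the library's internals):

```python
import torch
import torch.nn.functional as F

# Toy demonstration: centering the unembedding shifts all vocab logits
# at a given position by the same scalar, so log probs are unchanged.
torch.manual_seed(0)
d_model, d_vocab = 16, 32
resid = torch.randn(1, d_model)              # final residual stream, one position
W_U = torch.randn(d_model, d_vocab)          # toy unembedding matrix
W_U_centered = W_U - W_U.mean(dim=-1, keepdim=True)

logits = resid @ W_U
logits_centered = resid @ W_U_centered

shift = logits - logits_centered
print(shift.std())  # ~0: the shift is constant across the vocab dimension
diff = F.log_softmax(logits, -1) - F.log_softmax(logits_centered, -1)
print(diff.abs().max())  # ~0: log probs are invariant to the shift
```

So comparing `F.log_softmax` of the two models' logits (or loading with `from_pretrained_no_processing`) rules this shift out as the cause.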
…On Fri, 2 Aug 2024, 7:53 pm, Yeu-Tong Lau wrote:
**Describe the bug**

The output logits from transformer_lens and huggingface are quite different for the Gemma-2-2b-it model.

**Code example**
```python
import torch
import transformer_lens
from transformers import AutoTokenizer, AutoModelForCausalLM

device = 'cuda'
model_name = 'google/gemma-2-2b-it'
tl_model = transformer_lens.HookedTransformer.from_pretrained(model_name, device=device)

tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

inputs = tokenizer('Hello world', return_tensors="pt").to(device)

logits_tl = tl_model(inputs.input_ids, return_type='logits', prepend_bos=False)
logits_hf = hf_model(**inputs).logits

print((logits_tl[0, -1] - logits_hf[0, -1]).mean())        # 0.1159
print((logits_hf[0, -1]).min(), (logits_hf[0, -1]).max())  # -19.6916 16.0789
```
**System Info**

transformer_lens 2.3.0, transformers 4.43.2

**Additional context**

The logit diff is quite large.

**Checklist**

- I have checked that there is no [similar issue](https://github.com/TransformerLensOrg/TransformerLens/issues) in the repo (*required*)
@yeutong the issue is caused by a different attention scale (~14.96 vs 16). The HF implementation also disables the attention logit soft capping for inference, but that is less important.
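For concreteness, a hedged sketch of how both settings enter the attention scores. The 14.96 value is taken from the comment above; the 50.0 soft cap is the `attn_logit_softcapping` value in the Gemma-2 config. Treat the snippet as illustrative, not as either library's actual code:

```python
import torch

# Illustrative only: attention scores are (q @ k^T) / scale, optionally
# passed through tanh soft capping. A scale of ~14.96 vs 16.0 changes
# the softmax pattern, and hence the downstream logits.
torch.manual_seed(0)
d_head, seq = 256, 8
q = torch.randn(seq, d_head)
k = torch.randn(seq, d_head)

def attn_pattern(scale, soft_cap=None):
    scores = (q @ k.T) / scale
    if soft_cap is not None:  # Gemma-2 style tanh soft capping
        scores = soft_cap * torch.tanh(scores / soft_cap)
    return scores.softmax(dim=-1)

p_a = attn_pattern(14.96, soft_cap=50.0)  # scale reported in this thread
p_b = attn_pattern(16.0)                  # sqrt(d_head) = 16, no capping
print((p_a - p_b).abs().max())            # the patterns visibly diverge
```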
There is still some difference in the activations, but it is on the order of 5e-4 at the last layer. That is probably a deeper issue related to the use of einsum in the attention computation.
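A hedged sketch of the kind of numerical drift being suggested here: `einsum` and `matmul` can dispatch to different kernels, so mathematically identical contractions need not produce bit-identical floats, and tiny per-layer differences can compound:

```python
import torch

# Compare the same contraction via einsum and via matmul; on some
# backends the results differ by a small float epsilon rather than
# being identical (the exact gap is backend and dtype dependent).
torch.manual_seed(0)
q = torch.randn(8, 128, 256)  # (heads, query_pos, d_head)
k = torch.randn(8, 128, 256)  # (heads, key_pos, d_head)

scores_einsum = torch.einsum('hqd,hkd->hqk', q, k)
scores_matmul = q @ k.transpose(-1, -2)
print((scores_einsum - scores_matmul).abs().max())  # small, possibly non-zero
```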
bryce13950 added the complexity-high and implementation-inaccuracy labels on Aug 16, 2024