
[Bug Report] Gemma-2-2b-it output logit doesn't match with huggingface #693

Open · 1 task done
yeutong opened this issue Aug 2, 2024 · 3 comments · Fixed by #694
Labels
complexity-high: Very complicated changes for people to address who are quite familiar with the code
implementation-inaccuracy: Any issues related to our implementation being off from the official version

Comments

yeutong commented Aug 2, 2024

Describe the bug
The output logits from transformer_lens and huggingface are quite different when using the Gemma-2-2b-it model.

Code example

import torch
import transformer_lens
from transformers import AutoTokenizer, AutoModelForCausalLM

device = 'cuda'
model_name = 'google/gemma-2-2b-it'
tl_model = transformer_lens.HookedTransformer.from_pretrained(model_name, device=device)

tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

inputs = tokenizer('Hello world', return_tensors="pt").to(device)

logits_tl = tl_model(inputs.input_ids, return_type='logits', prepend_bos=False)
logits_hf = hf_model(**inputs).logits

print((logits_tl[0, -1] - logits_hf[0, -1]).mean()) # 0.1159
print((logits_hf[0, -1]).min(), (logits_hf[0, -1]).max()) # -19.6916 16.0789
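
For a fuller picture than the mean (which can cancel across signs), a minimal follow-up sketch reusing the tensors above:

# Sketch (reuses logits_tl / logits_hf from the snippet above): max absolute
# difference and norm are less prone to sign cancellation than the mean.
diff = logits_tl[0, -1] - logits_hf[0, -1]
print(diff.abs().max().item(), diff.norm().item())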

System Info
transformer_lens 2.3.0, transformers 4.43.2

Additional context
The logit diff is quite large

Checklist

  • I have checked that there is no similar issue in the repo (required)
neelnanda-io (Collaborator) commented Aug 2, 2024 via email

yeutong (Author) commented Aug 2, 2024

Tried from_pretrained_no_processing and got the same results. It is more than the unembedding centering: the differences exist in the model activations and grow larger at each layer.
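
For reference, a minimal sketch of the no-processing load (model_name and device are reused from the reproduction above):

# Sketch: load without TransformerLens weight processing (no folding/centering);
# model_name and device are the same as in the reproduction above.
tl_model = transformer_lens.HookedTransformer.from_pretrained_no_processing(
    model_name, device=device
)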

def forward_with_cache(model, layer, inputs):
    # Capture the input to model.model.layers[layer], i.e. the residual stream
    # entering that layer (the HF analogue of TL's hook_resid_pre).
    cache = None
    def hook(module, inputs, outputs):
        nonlocal cache
        cache = inputs[0]
        return outputs

    hook_handle = model.model.layers[layer].register_forward_hook(hook)
    _ = model(**inputs)
    hook_handle.remove()

    return cache

resid_pre_diffs = []

for layer in range(tl_model.cfg.n_layers):
    hf_cache = forward_with_cache(hf_model, layer, inputs)
    _, tl_cache = tl_model.run_with_cache(inputs.input_ids, prepend_bos=False, names_filter=[f'blocks.{layer}.hook_resid_pre'])
    tl_cache = tl_cache[f'blocks.{layer}.hook_resid_pre']
    resid_pre_diff = (hf_cache - tl_cache)[0, -1].norm().item()
    resid_pre_diffs.append(resid_pre_diff)

import plotly.express as px
px.line(resid_pre_diffs, markers=True, labels={'index': 'Layer', 'value': 'norm of resid pre diff'}, title='Difference in resid_pre between HF and TL')
[Figure: 'Difference in resid_pre between HF and TL', norm of the resid_pre diff per layer, growing with depth]

mntss (Contributor) commented Aug 7, 2024

@yeutong the issue is caused by a different attention scale being used (~14.96 vs 16). The HF implementation also disables the attention logit soft capping for inference, but that is less important:

for b in tl_model.blocks:
    b.attn.attn_scale = 16                # match the attention scale used by HF
    b.attn.cfg.attn_scores_soft_cap = 0   # HF disables the soft cap at inference
# re-run the resid_pre comparison above to regenerate resid_diff
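
If it helps, a minimal sketch of re-checking the final logits after the patch above (reusing tl_model, hf_model, and inputs from the original reproduction):

# Sketch: re-run the logits comparison from the reproduction after overriding
# attn_scale and the soft cap; all names are reused from that snippet.
logits_tl = tl_model(inputs.input_ids, return_type='logits', prepend_bos=False)
logits_hf = hf_model(**inputs).logits
print((logits_tl[0, -1] - logits_hf[0, -1]).abs().max())  # should now be far smaller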

There is still some difference in the activations, but it is on the order of 5e-4 at the last layer. That is probably a deeper issue related to the use of einsum in the attention.
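
As a generic illustration of the kind of numerical discrepancy being suggested (not the TransformerLens attention code itself), two mathematically equivalent contractions can accumulate in different orders; the observed gap depends on backend and dtype and may be zero:

import torch

torch.manual_seed(0)
q = torch.randn(8, 64, 256)  # (heads, query_pos, d_head), arbitrary illustrative shapes
k = torch.randn(8, 64, 256)  # (heads, key_pos, d_head)

scores_einsum = torch.einsum('hqd,hkd->hqk', q, k)
scores_matmul = q @ k.transpose(-1, -2)

# Mathematically identical; any nonzero gap comes from floating-point accumulation order.
print((scores_einsum - scores_matmul).abs().max().item())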

mntss mentioned this issue Aug 7, 2024
bryce13950 added the complexity-high and implementation-inaccuracy labels Aug 16, 2024