
[Suggestion]: Support Causal Tracing for LLaMA model #174

Open
aryopg opened this issue Jul 18, 2024 · 0 comments
Labels
enhancement New feature or request

aryopg commented Jul 18, 2024

Suggestion / Feature Request

I've tried modifying the embed_to_distrib function in pyvene/models/basic_utils.py to also support LLaMA models, as follows:

def embed_to_distrib(model, embed, log=False, logits=False):
    """Convert an embedding to a distribution over the vocabulary"""
    if "gpt2" in model.config.architectures[0].lower():
        with torch.inference_mode():
            vocab = torch.matmul(embed, model.wte.weight.t())
            if logits:
                return vocab
            return lsm(vocab) if log else sm(vocab)
    elif "llama" in model.config.architectures[0].lower():
        with torch.inference_mode():
            # LLaMA does not tie the input embeddings to the output
            # projection, so project through lm_head instead of wte.weight.t()
            vocab = model.lm_head(embed)
            if logits:
                return vocab
            return lsm(vocab) if log else sm(vocab)

It seems to work fine when doing causal tracing (see images below):

  • Single restored layer (screenshot attached)
  • MLP layer (screenshot attached)
  • Attention layer (screenshot attached)

Would this be the correct approach for the LLaMA model, and would it be of interest for pyvene?
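For what it's worth, the LLaMA branch can be smoke-tested without loading full model weights by stubbing an object whose `config.architectures` reports a LLaMA class and whose `lm_head` is a small linear layer. A minimal sketch (the `_Cfg`/`_Stub` names are hypothetical, and `lsm`/`sm` stand in for the module-level softmaxes that `basic_utils.py` already defines):

```python
import torch
import torch.nn as nn

# Stand-ins for the softmaxes defined at module level in basic_utils.py
lsm = nn.LogSoftmax(dim=2)
sm = nn.Softmax(dim=2)

def embed_to_distrib(model, embed, log=False, logits=False):
    """Convert an embedding to a distribution over the vocabulary (LLaMA branch only)."""
    if "llama" in model.config.architectures[0].lower():
        with torch.inference_mode():
            # LLaMA's output projection is lm_head, not a tied wte matrix
            vocab = model.lm_head(embed)
            if logits:
                return vocab
            return lsm(vocab) if log else sm(vocab)
    raise ValueError("this sketch only handles LLaMA-style architectures")

# Hypothetical stub standing in for a real LlamaForCausalLM
class _Cfg:
    architectures = ["LlamaForCausalLM"]

class _Stub(nn.Module):
    def __init__(self, hidden=8, vocab=32):
        super().__init__()
        self.config = _Cfg()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

model = _Stub()
embed = torch.randn(1, 3, 8)            # (batch, seq, hidden)
probs = embed_to_distrib(model, embed)  # softmax over the vocab dimension
print(tuple(probs.shape))
```

Each position's distribution should sum to 1 over the vocabulary, which is a quick sanity check that the branch routes through `lm_head` as intended.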

@aryopg aryopg added the enhancement New feature or request label Jul 18, 2024