
feat: add intervenable_model to forward's function signature #191

Open · wants to merge 4 commits into base: main
Conversation

eggachecat

Description

Added the intervenable_model parameter to the forward function signature, enabling user-defined intervention classes to have direct access to the model instance. This allows for more advanced manipulations such as using intervenable_model.model.lm_head(base) to interact with lower-level model components.
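As a minimal, self-contained sketch of the pattern this PR enables (the `TinyLM`, `TinyWrapper`, and `LmHeadIntervention` names are illustrative stand-ins, not pyvene's actual API):

```python
class TinyLM:
    """Stand-in model exposing an lm_head, so the sketch runs standalone."""
    def lm_head(self, hidden):
        # Pretend projection from hidden states to "logits".
        return [h * 2.0 for h in hidden]

class TinyWrapper:
    """Stand-in for the intervenable-model wrapper holding .model."""
    def __init__(self, model):
        self.model = model

class LmHeadIntervention:
    """Illustrative intervention whose forward receives the model handle."""
    def forward(self, base, source=None, subspaces=None, intervenable_model=None):
        # With the handle available, the intervention can reach lower-level
        # components such as the LM head, as described above.
        self.last_logits = intervenable_model.model.lm_head(base)
        return base  # pass the representation through unchanged

wrapper = TinyWrapper(TinyLM())
iv = LmHeadIntervention()
out = iv.forward([1.0, 2.0], intervenable_model=wrapper)
print(out)             # [1.0, 2.0]
print(iv.last_logits)  # [2.0, 4.0]
```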

Testing Done

Tested the changes locally by defining custom intervention classes that access the model's internal components, including the lm_head. Verified that these interventions function as expected during model execution.

Checklist:

  • My PR title strictly follows the format: [Your Priority] Your Title
  • I have attached the testing log above
  • I provide enough comments to my code
  • I have changed documentations
  • I have added tests for my changes

Enable user-defined intervention classes to access the model.
This allows users to interact with the model more flexibly by using constructs like `model.model.lm_head(base)`.
@frankaging
Collaborator

frankaging commented Oct 9, 2024

Thanks for the change! The use case seems useful.

One general comment: could you turn the intervention signature into a more generic version using **kwargs?

def forward(self, base, source, subspaces=None, **kwargs):
    ...

The callers in these setter functions should also take in **kwargs from the user; if a kwarg is passed, set it accordingly.

The intervenable model forward call thus can take in arguments such as,

pv_model.forward(base=..., sources=[...], intervenable_model=pv_model)

Let me know if this makes sense! If you could make this change, it would be great since it will support many use cases.
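The kwargs-forwarding pattern suggested above can be sketched as follows (the `GenericIntervention` class and `intervention_setter` function are hypothetical illustrations, not the library's actual internals):

```python
class GenericIntervention:
    """Generic signature: extra user fields arrive via **kwargs
    rather than a hard-coded parameter."""
    def forward(self, base, source=None, subspaces=None, **kwargs):
        # Present only if the caller chose to pass it.
        self.model_handle = kwargs.get("intervenable_model")
        return base

def intervention_setter(base, intervention, subspaces=None, **kwargs):
    """Caller that forwards user-supplied kwargs to the intervention unchanged."""
    return intervention.forward(base, None, subspaces, **kwargs)

iv = GenericIntervention()
# A caller-supplied field flows through the setter into the intervention:
result = intervention_setter([1, 2, 3], iv, intervenable_model="model-handle")
print(result)           # [1, 2, 3]
print(iv.model_handle)  # model-handle
```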

@eggachecat
Author

Hi @frankaging, I think it totally makes sense; I've updated the code accordingly. Please take a look when you have a chance. Thanks!

Collaborator

@frankaging left a comment


Thanks!! Otherwise, the change LGTM!

@@ -839,6 +839,7 @@ def _intervention_setter(
                 None,
                 intervention,
                 subspaces[key_i] if subspaces is not None else None,
+                intervenable_model=self

Could you also change this to passing **kwargs?

@@ -431,7 +431,7 @@ def scatter_neurons(


 def do_intervention(
-    base_representation, source_representation, intervention, subspaces
+    base_representation, source_representation, intervention, subspaces, intervenable_model=None

Similarly, could you also change this to passing **kwargs?

} for layer in [1, 3]], model=self.llama)
intervened_outputs = pv_llama(
base=self.tokenizer("The capital of Spain is", return_tensors="pt").to(that.device),
unit_locations={"base": 3}

With these changes, at this line you can now pass your customized fields through to the intervention, such as a self-reference to the model.

@eggachecat
Author

Hi @frankaging, I just made another PR.

Now the function signature for users looks like:

    def test_with_llm_head(self):
        that = self
        _lm_head_collection = {}
        class AccessIntervenableModelIntervention:
            is_source_constant = True
            keep_last_dim = True
            intervention_types = 'access_intervenable_model_intervention'
            def __init__(self, layer_index, *args, **kwargs):
                super().__init__()
                self.layer_index = layer_index
            def __call__(self, base, source=None, subspaces=None, model=None, **kwargs):
                intervenable_model = kwargs.get('intervenable_model', None)
                assert intervenable_model is not None
                _lm_head_collection[self.layer_index] = intervenable_model.model.lm_head(base.to(that.device))
                return base
        # run with new intervention type
        pv_llama = IntervenableModel([{
            "intervention": AccessIntervenableModelIntervention(layer_index=layer),
            "component": f"model.layers.{layer}.input"
        } for layer in [1, 3]], model=self.llama)
        intervened_outputs = pv_llama(
            base=self.tokenizer("The capital of Spain is", return_tensors="pt").to(that.device), 
            unit_locations={"base": 3},
            # anything passed here will be forwarded to the __call__
            intervenable_model=pv_llama 
        )

@eggachecat
Author

👀👀👀👀
