-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add intervenable_model to forward's function signature #191
base: main
Are you sure you want to change the base?
Conversation
Enable user-defined intervention classes to access the model. This allows users to interact with the model more flexibly by using constructs like `model.model.lm_head(base)`.
Thanks for the change! The use case seems to be useful. One general comment, could you turn the intervention signature into a more generic version using kwargs?
And the caller in these setter function should also take in **kwargs from user, if it is passed, then set. The intervenable model forward call thus can take in arguments such as,
Let me know if this makes sense! If you could make this change, it would be great since it will support many use cases. |
Hi @frankaging , I think it totally makes sense and have updated the code accordingly. Please take a look when you have a chance. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!! Otherwise, the change LGTM!
pyvene/models/intervenable_base.py
Outdated
@@ -839,6 +839,7 @@ def _intervention_setter( | |||
None, | |||
intervention, | |||
subspaces[key_i] if subspaces is not None else None, | |||
intervenable_model=self |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you also change this to passing **kwargs?
pyvene/models/modeling_utils.py
Outdated
@@ -431,7 +431,7 @@ def scatter_neurons( | |||
|
|||
|
|||
def do_intervention( | |||
base_representation, source_representation, intervention, subspaces | |||
base_representation, source_representation, intervention, subspaces, intervenable_model=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similarly, could you also change this to passing **kwargs?
} for layer in [1, 3]], model=self.llama) | ||
intervened_outputs = pv_llama( | ||
base=self.tokenizer("The capital of Spain is", return_tensors="pt").to(that.device), | ||
unit_locations={"base": 3} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with the changes, now in this line, you could pass in your customized field to the model, such as self-referencing.
hi @frankaging I just made another PR Now the function signature for users looks like def test_with_llm_head(self):
that = self
_lm_head_collection = {}
class AccessIntervenableModelIntervention:
is_source_constant = True
keep_last_dim = True
intervention_types = 'access_intervenable_model_intervention'
def __init__(self, layer_index, *args, **kwargs):
super().__init__()
self.layer_index = layer_index
def __call__(self, base, source=None, subspaces=None, model=None, **kwargs):
intervenable_model = kwargs.get('intervenable_model', None)
assert intervenable_model is not None
_lm_head_collection[self.layer_index] = intervenable_model.model.lm_head(base.to(that.device))
return base
# run with new intervention type
pv_llama = IntervenableModel([{
"intervention": AccessIntervenableModelIntervention(layer_index=layer),
"component": f"model.layers.{layer}.input"
} for layer in [1, 3]], model=self.llama)
intervened_outputs = pv_llama(
base=self.tokenizer("The capital of Spain is", return_tensors="pt").to(that.device),
unit_locations={"base": 3},
# anything passed here will be forwarded to the __call__
intervenable_model=pv_llama
) |
👀👀👀👀 |
Description
Added the
intervenable_model
parameter to the forward function signature, enabling user-defined intervention classes to have direct access to the model instance. This allows for more advanced manipulations such as usingintervenable_model.model.lm_head(base)
to interact with lower-level model components.Testing Done
Tested the changes locally by defining custom intervention classes that access the model's internal components, including the
lm_head
. Verified that these interventions function as expected during model execution.Checklist:
[Your Priority] Your Title