feat: add intervenable_model to forward's function signature #191

Open · wants to merge 4 commits into base: main
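In short, this change threads any extra keyword arguments given to IntervenableModel.forward (and generate) down through _wait_for_forward_with_parallel/serial_intervention and do_intervention to each intervention's own forward call, so an intervention can receive call-time context such as the wrapping IntervenableModel itself. A minimal usage sketch, closely following the new test case added in InterventionWithLlamaTestCase.py below; the `model` and `tokenizer` objects are assumed to be a Llama-style HuggingFace causal LM and its tokenizer, and all other names are illustrative:

```python
import pyvene as pv

probe_logits = {}

class AccessIntervenableModelIntervention:
    # a 'primitive' intervention (plain callable), mirroring the new test case
    is_source_constant = True
    keep_last_dim = True
    intervention_types = "access_intervenable_model_intervention"

    def __call__(self, base, source=None, subspaces=None, **kwargs):
        # Any extra kwargs passed to pv_model(...) arrive here, e.g. the
        # IntervenableModel instance itself.
        intervenable_model = kwargs.get("intervenable_model", None)
        assert intervenable_model is not None
        # use the wrapped model's LM head to decode the current hidden state
        probe_logits["layer_3"] = intervenable_model.model.lm_head(base)
        return base  # leave the representation unchanged

pv_model = pv.IntervenableModel(
    [{"component": "model.layers.3.input",
      "intervention": AccessIntervenableModelIntervention()}],
    model=model,
)
outputs = pv_model(
    base=tokenizer("The capital of Spain is", return_tensors="pt"),
    unit_locations={"base": 3},
    intervenable_model=pv_model,  # new: forwarded into the intervention's **kwargs
)
```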
23 changes: 22 additions & 1 deletion pyvene/models/intervenable_base.py
@@ -796,6 +796,7 @@ def _intervention_setter(
keys,
unit_locations_base,
subspaces,
**intervention_forward_kwargs
) -> HandlerList:
"""
Create a list of setter tracer that will set activations
@@ -839,6 +840,7 @@ def _intervention_setter(
None,
intervention,
subspaces[key_i] if subspaces is not None else None,
**intervention_forward_kwargs
)
# fail if this is not a fresh collect
assert key not in self.activations
@@ -853,6 +855,7 @@ def _intervention_setter(
None,
intervention,
subspaces[key_i] if subspaces is not None else None,
**intervention_forward_kwargs
)
else:
intervened_representation = do_intervention(
@@ -864,6 +867,7 @@
),
intervention,
subspaces[key_i] if subspaces is not None else None,
**intervention_forward_kwargs
)
else:
# highly unlikely it's a primitive intervention type
@@ -876,6 +880,7 @@
),
intervention,
subspaces[key_i] if subspaces is not None else None,
**intervention_forward_kwargs
)
if intervened_representation is None:
return
@@ -961,6 +966,7 @@ def _sync_forward_with_parallel_intervention(
]
if subspaces is not None
else None,
**kwargs
)
counterfactual_outputs = self.model.output.save()

@@ -988,6 +994,7 @@ def forward(
output_original_output: Optional[bool] = False,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
**kwargs
):
activations_sources = source_representations
if sources is not None and not isinstance(sources, list):
@@ -1027,7 +1034,7 @@ def forward(
try:

# run intervened forward
model_kwargs = {}
model_kwargs = { **kwargs }
if labels is not None: # for training
model_kwargs["labels"] = labels
if use_cache is not None and 'use_cache' in self.model.config.to_dict(): # for transformer models
@@ -1507,6 +1514,7 @@ def _intervention_setter(
keys,
unit_locations_base,
subspaces,
**intervention_forward_kwargs
) -> HandlerList:
"""
Create a list of setter handlers that will set activations
@@ -1553,6 +1561,7 @@ def hook_callback(model, args, kwargs, output=None):
None,
intervention,
subspaces[key_i] if subspaces is not None else None,
**intervention_forward_kwargs
)
# fail if this is not a fresh collect
assert key not in self.activations
@@ -1568,6 +1577,7 @@ def hook_callback(model, args, kwargs, output=None):
None,
intervention,
subspaces[key_i] if subspaces is not None else None,
**intervention_forward_kwargs
)
if isinstance(raw_intervened_representation, InterventionOutput):
self.full_intervention_outputs.append(raw_intervened_representation)
@@ -1584,6 +1594,7 @@ def hook_callback(model, args, kwargs, output=None):
),
intervention,
subspaces[key_i] if subspaces is not None else None,
**intervention_forward_kwargs
)
else:
# highly unlikely it's a primitive intervention type
@@ -1596,6 +1607,7 @@ def hook_callback(model, args, kwargs, output=None):
),
intervention,
subspaces[key_i] if subspaces is not None else None,
**intervention_forward_kwargs
)
if intervened_representation is None:
return
@@ -1663,6 +1675,7 @@ def _wait_for_forward_with_parallel_intervention(
unit_locations,
activations_sources: Optional[Dict] = None,
subspaces: Optional[List] = None,
**intervention_forward_kwargs
):
# torch.autograd.set_detect_anomaly(True)
all_set_handlers = HandlerList([])
@@ -1718,6 +1731,7 @@ def _wait_for_forward_with_parallel_intervention(
]
if subspaces is not None
else None,
**intervention_forward_kwargs
)
# for setters, we don't remove them.
all_set_handlers.extend(set_handlers)
@@ -1729,6 +1743,7 @@ def _wait_for_forward_with_serial_intervention(
unit_locations,
activations_sources: Optional[Dict] = None,
subspaces: Optional[List] = None,
**intervention_forward_kwargs
):
all_set_handlers = HandlerList([])
for group_id, keys in self._intervention_group.items():
@@ -1785,6 +1800,7 @@ def _wait_for_forward_with_serial_intervention(
]
if subspaces is not None
else None,
**intervention_forward_kwargs
)
# for setters, we don't remove them.
all_set_handlers.extend(set_handlers)
@@ -1801,6 +1817,7 @@ def forward(
output_original_output: Optional[bool] = False,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
**intervention_forward_kwargs
):
"""
Main forward function that serves a wrapper to
@@ -1909,6 +1926,7 @@ def forward(
unit_locations,
activations_sources,
subspaces,
**intervention_forward_kwargs
)
)
elif self.mode == "serial":
@@ -1918,6 +1936,7 @@ def forward(
unit_locations,
activations_sources,
subspaces,
**intervention_forward_kwargs
)
)

@@ -2051,6 +2070,7 @@ def generate(
unit_locations,
activations_sources,
subspaces,
**kwargs
)
)
elif self.mode == "serial":
@@ -2060,6 +2080,7 @@ def generate(
unit_locations,
activations_sources,
subspaces,
**kwargs
)
)

32 changes: 16 additions & 16 deletions pyvene/models/interventions.py
@@ -75,7 +75,7 @@ def set_interchange_dim(self, interchange_dim):
self.interchange_dim = interchange_dim

@abstractmethod
def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
pass


@@ -153,7 +153,7 @@ class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationInterve
def __init__(self, **kwargs):
super().__init__(**kwargs)

def forward(self, base, source=None, subspaces=None):
def forward(self, base, source=None, subspaces=None, **kwargs):
return _do_intervention_by_swap(
base,
torch.zeros_like(base),
@@ -175,7 +175,7 @@ class CollectIntervention(ConstantSourceIntervention):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def forward(self, base, source=None, subspaces=None):
def forward(self, base, source=None, subspaces=None, **kwargs):
return _do_intervention_by_swap(
base,
source,
@@ -197,7 +197,7 @@ class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationInterven
def __init__(self, **kwargs):
super().__init__(**kwargs)

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
# source here is the base example input to the hook
return _do_intervention_by_swap(
base,
@@ -220,7 +220,7 @@ class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
return _do_intervention_by_swap(
base,
source if self.source_representation is None else self.source_representation,
@@ -242,7 +242,7 @@ class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationInte
def __init__(self, **kwargs):
super().__init__(**kwargs)

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
return _do_intervention_by_swap(
base,
source if self.source_representation is None else self.source_representation,
@@ -264,7 +264,7 @@ class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationI
def __init__(self, **kwargs):
super().__init__(**kwargs)

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):

return _do_intervention_by_swap(
base,
@@ -289,7 +289,7 @@ def __init__(self, **kwargs):
rotate_layer = RotateLayer(self.embed_dim)
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
rotated_base = self.rotate_layer(base)
rotated_source = self.rotate_layer(source)
# interchange
@@ -340,7 +340,7 @@ def set_intervention_boundaries(self, intervention_boundaries):
torch.tensor([intervention_boundaries]), requires_grad=True
)

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
batch_size = base.shape[0]
rotated_base = self.rotate_layer(base)
rotated_source = self.rotate_layer(source)
@@ -391,7 +391,7 @@ def get_temperature(self):
def set_temperature(self, temp: torch.Tensor):
self.temperature.data = temp

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
batch_size = base.shape[0]
rotated_base = self.rotate_layer(base)
rotated_source = self.rotate_layer(source)
@@ -431,7 +431,7 @@ def get_temperature(self):
def set_temperature(self, temp: torch.Tensor):
self.temperature.data = temp

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
batch_size = base.shape[0]
# get boundary mask between 0 and 1 from sigmoid
mask_sigmoid = torch.sigmoid(self.mask / torch.tensor(self.temperature))
@@ -456,7 +456,7 @@ def __init__(self, **kwargs):
rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
rotated_base = self.rotate_layer(base)
rotated_source = self.rotate_layer(source)
if subspaces is not None:
@@ -529,7 +529,7 @@ def __init__(self, **kwargs):
)
self.trainable = False

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
base_norm = (base - self.pca_mean) / self.pca_std
source_norm = (source - self.pca_mean) / self.pca_std

@@ -565,7 +565,7 @@ def __init__(self, **kwargs):
prng(1, 4, self.embed_dim)))
self.register_buffer('noise_level', torch.tensor(noise_level))

def forward(self, base, source=None, subspaces=None):
def forward(self, base, source=None, subspaces=None, **kwargs):
base[..., : self.interchange_dim] += self.noise * self.noise_level
return base

@@ -585,7 +585,7 @@ def __init__(self, **kwargs):
self.autoencoder = AutoencoderLayer(
self.embed_dim, kwargs["latent_dim"])

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None, **kwargs):
base_dtype = base.dtype
base = base.to(self.autoencoder.encoder[0].weight.dtype)
base_latent = self.autoencoder.encode(base)
@@ -619,7 +619,7 @@ def encode(self, input_acts):
def decode(self, acts):
return acts @ self.W_dec + self.b_dec

def forward(self, base, source=None, subspaces=None):
def forward(self, base, source=None, subspaces=None, **kwargs):
# generate latents for base and source runs.
base_latent = self.encode(base)
source_latent = self.encode(source)
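Every built-in intervention's forward now accepts **kwargs and simply ignores them, so call-time keyword arguments meant for one custom intervention do not break the built-ins that run alongside it. A small sketch of that behavior; the embed/interchange dimension of 8 is arbitrary:

```python
import torch
import pyvene as pv

intervention = pv.VanillaIntervention(embed_dim=8)
intervention.set_interchange_dim(8)  # interchange the full 8-dim representation

base = torch.randn(2, 8)
source = torch.randn(2, 8)

# Before this change an unexpected keyword argument would raise a TypeError;
# now it is accepted and ignored, and the vanilla interchange still happens.
out = intervention(base, source, subspaces=None, intervenable_model=None)
assert torch.allclose(out, source)
```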
5 changes: 3 additions & 2 deletions pyvene/models/modeling_utils.py
@@ -431,7 +431,7 @@ def scatter_neurons(


def do_intervention(
base_representation, source_representation, intervention, subspaces
base_representation, source_representation, intervention, subspaces, **intervention_forward_kwargs
):
"""Do the actual intervention."""

@@ -463,7 +463,8 @@ def do_intervention(
assert False # what's going on?

intervention_output = intervention(
base_representation_f, source_representation_f, subspaces
base_representation_f, source_representation_f, subspaces,
**intervention_forward_kwargs
)
if isinstance(intervention_output, InterventionOutput):
intervened_representation = intervention_output.output
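For completeness, the modeling_utils.py change just hands the collected keyword arguments to the intervention callable itself. A stripped-down paraphrase (not the verbatim pyvene source; representation reshaping and InterventionOutput handling are omitted):

```python
def do_intervention_sketch(base_repr, source_repr, intervention, subspaces, **intervention_forward_kwargs):
    # New part: the extra kwargs are passed straight through to the intervention.
    return intervention(base_repr, source_repr, subspaces, **intervention_forward_kwargs)
```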
2 changes: 1 addition & 1 deletion tests/integration_tests/IntervenableBasicTestCase.py
@@ -232,7 +232,7 @@ class MultiplierIntervention(
def __init__(self, embed_dim, **kwargs):
super().__init__()
def forward(
self, base, source=None, subspaces=None):
self, base, source=None, subspaces=None, **kwargs):
return base * 99.0
# run with new intervention type
pv_gpt2 = pv.IntervenableModel({
26 changes: 26 additions & 0 deletions tests/integration_tests/InterventionWithLlamaTestCase.py
@@ -156,6 +156,32 @@ def test_with_multiple_heads_positions_vanilla_intervention_positive(self):
heads=[4, 1],
positions=[7, 2],
)

def test_with_llm_head(self):
that = self
_lm_head_collection = {}
class AccessIntervenableModelIntervention:
is_source_constant = True
keep_last_dim = True
intervention_types = 'access_intervenable_model_intervention'
def __init__(self, layer_index, *args, **kwargs):
super().__init__()
self.layer_index = layer_index
def __call__(self, base, source=None, subspaces=None, model=None, **kwargs):
intervenable_model = kwargs.get('intervenable_model', None)
assert intervenable_model is not None
_lm_head_collection[self.layer_index] = intervenable_model.model.lm_head(base.to(that.device))
return base
# run with new intervention type
pv_llama = IntervenableModel([{
"intervention": AccessIntervenableModelIntervention(layer_index=layer),
"component": f"model.layers.{layer}.input"
} for layer in [1, 3]], model=self.llama)
intervened_outputs = pv_llama(
base=self.tokenizer("The capital of Spain is", return_tensors="pt").to(that.device),
unit_locations={"base": 3},
intervenable_model=pv_llama
)


def suite():