From d40fff9e1fca0b34e903c2c23dbcecc4a8ca99f5 Mon Sep 17 00:00:00 2001 From: frankaging Date: Wed, 17 Jan 2024 01:15:24 -0800 Subject: [PATCH 1/3] Adding in constant source intervention support with new tests --- .../configuration_intervenable_model.py | 5 +- pyvene/models/intervenable_base.py | 82 +++++--- pyvene/models/intervention_utils.py | 21 ++ pyvene/models/interventions.py | 75 ++++++- pyvene/models/modeling_utils.py | 8 +- .../InterventionWithGPT2TestCase.py | 193 +++++++++++++++++- tests/utils.py | 6 +- 7 files changed, 339 insertions(+), 51 deletions(-) diff --git a/pyvene/models/configuration_intervenable_model.py b/pyvene/models/configuration_intervenable_model.py index 7ef697bd..629e16dd 100644 --- a/pyvene/models/configuration_intervenable_model.py +++ b/pyvene/models/configuration_intervenable_model.py @@ -13,8 +13,9 @@ "intervenable_layer intervenable_representation_type " "intervenable_unit max_number_of_units " "intervenable_low_rank_dimension " - "subspace_partition group_key intervention_link_key", - defaults=(0, "block_output", "pos", 1, None, None, None, None), + "subspace_partition group_key intervention_link_key intervenable_moe " + "source_representation", + defaults=(0, "block_output", "pos", 1, None, None, None, None, None, None), ) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index c66f9f87..2f1397b9 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -113,8 +113,10 @@ def __init__(self, intervenable_config, model, **kwargs): get_internal_model_type(model), model.config, representation ), proj_dim=representation.intervenable_low_rank_dimension, - # we can partition the subspace, and intervene on subspace + # additional args subspace_partition=representation.subspace_partition, + use_fast=self.use_fast, + source_representation=representation.source_representation, ) if representation.intervention_link_key in self._intervention_pointers: self._intervention_reverse_link[ @@ -129,9 +131,10 @@ def __init__(self, intervenable_config, model, **kwargs): get_internal_model_type(model), model.config, representation ), proj_dim=representation.intervenable_low_rank_dimension, - # we can partition the subspace, and intervene on subspace + # additional args subspace_partition=representation.subspace_partition, use_fast=self.use_fast, + source_representation=representation.source_representation, ) # we cache the intervention for sharing if the key is not None if representation.intervention_link_key is not None: @@ -803,8 +806,9 @@ def hook_callback(model, args, kwargs, output=None): if not self.is_model_stateless: selected_output = selected_output.clone() + if isinstance( - intervention, + intervention, CollectIntervention ): intervened_representation = do_intervention( @@ -820,16 +824,24 @@ def hook_callback(model, args, kwargs, output=None): # no-op to the output else: - intervened_representation = do_intervention( - selected_output, - self._reconcile_stateful_cached_activations( - key, + if intervention.is_source_constant: + intervened_representation = do_intervention( selected_output, - unit_locations_base[key_i], - ), - intervention, - subspaces[key_i] if subspaces is not None else None, - ) + None, + intervention, + subspaces[key_i] if subspaces is not None else None, + ) + else: + intervened_representation = do_intervention( + selected_output, + self._reconcile_stateful_cached_activations( + key, + selected_output, + unit_locations_base[key_i], + ), + intervention, + subspaces[key_i] if subspaces is not None else None, + ) # setter can produce hot activations for shared subspace interventions if linked if key in self._intervention_reverse_link: @@ -873,10 +885,10 @@ def _input_validation( ): """Fail fast input validation""" if self.mode == "parallel": - assert "sources->base" in unit_locations + assert "sources->base" in unit_locations or "base" in unit_locations elif activations_sources is None and self.mode == "serial": assert "sources->base" not in unit_locations - + # sources may contain None, but length should match if sources is not None: if len(sources) != len(self._intervention_group): @@ -982,10 +994,7 @@ def _wait_for_forward_with_parallel_intervention( for intervenable_key in intervenable_keys: # skip in case smart jump if intervenable_key in self.activations or \ - isinstance( - self.interventions[intervenable_key][0], - CollectIntervention - ): + self.interventions[intervenable_key][0].is_source_constant: set_handlers = self._intervention_setter( [intervenable_key], [ @@ -1054,10 +1063,7 @@ def _wait_for_forward_with_serial_intervention( for intervenable_key in intervenable_keys: # skip in case smart jump if intervenable_key in self.activations or \ - isinstance( - self.interventions[intervenable_key][0], - CollectIntervention - ): + self.interventions[intervenable_key][0].is_source_constant: # set with intervened activation to source_i+1 set_handlers = self._intervention_setter( [intervenable_key], @@ -1080,21 +1086,30 @@ def _broadcast_unit_locations( batch_size, unit_locations ): - _unit_locations = copy.deepcopy(unit_locations) + _unit_locations = {} for k, v in unit_locations.items(): + # special broadcast for base-only interventions + is_base_only = False + if k == "base": + is_base_only = True + k = "sources->base" if isinstance(v, int): _unit_locations[k] = ([[[v]]*batch_size], [[[v]]*batch_size]) self.use_fast = True - elif isinstance(v[0], int) and isinstance(v[1], int): + elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int): _unit_locations[k] = ([[[v[0]]]*batch_size], [[[v[1]]]*batch_size]) self.use_fast = True - elif isinstance(v[0], list) and isinstance(v[1], list): - pass # we don't support boardcase here yet. + elif len(v) == 2 and v[0] == None and isinstance(v[1], int): + _unit_locations[k] = (None, [[[v[1]]]*batch_size]) + self.use_fast = True + elif len(v) == 2 and isinstance(v[0], int) and v[1] == None: + _unit_locations[k] = ([[[v[0]]]*batch_size], None) + self.use_fast = True else: - raise ValueError( - f"unit_locations {unit_locations} contains invalid format." - ) - + if is_base_only: + _unit_locations[k] = (None, v) + else: + _unit_locations[k] = v return _unit_locations def forward( @@ -1173,12 +1188,15 @@ def forward( self._cleanup_states() # if no source inputs, we are calling a simple forward - if sources is None and activations_sources is None: + if sources is None and activations_sources is None \ + and unit_locations is None: return self.model(**base), None unit_locations = self._broadcast_unit_locations( get_batch_size(base), unit_locations) + sources = [None] if sources is None else sources + self._input_validation( base, sources, @@ -1287,6 +1305,8 @@ def generate( unit_locations = self._broadcast_unit_locations( get_batch_size(base), unit_locations) + sources = [None] if sources is None else None + self._input_validation( base, sources, diff --git a/pyvene/models/intervention_utils.py b/pyvene/models/intervention_utils.py index 8add3e40..0fd95b5d 100644 --- a/pyvene/models/intervention_utils.py +++ b/pyvene/models/intervention_utils.py @@ -37,7 +37,19 @@ def __repr__(self): def __str__(self): return json.dumps(self.state_dict, indent=4) +def broadcast_tensor(x, target_shape): + # Ensure the last dimension of target_shape matches x's size + if target_shape[-1] != x.shape[-1]: + raise ValueError("The last dimension of target_shape must match the size of x") + # Create a shape for reshaping x that is compatible with target_shape + reshape_shape = [1] * (len(target_shape) - 1) + [x.shape[-1]] + + # Reshape x and then broadcast it + x_reshaped = x.view(*reshape_shape) + broadcasted_x = x_reshaped.expand(*target_shape) + return broadcasted_x + def _do_intervention_by_swap( base, source, @@ -50,6 +62,15 @@ def _do_intervention_by_swap( """The basic do function that guards interventions""" if mode == "collect": assert source is None + # auto broadcast + if base.shape != source.shape: + try: + source = broadcast_tensor(source, base.shape) + except: + raise ValueError( + f"source with shape {source.shape} cannot be broadcasted " + f"into base with shape {base.shape}." + ) # interchange if use_fast: if subspaces is not None: diff --git a/pyvene/models/interventions.py b/pyvene/models/interventions.py index d4904315..bce5b4f6 100644 --- a/pyvene/models/interventions.py +++ b/pyvene/models/interventions.py @@ -14,7 +14,8 @@ def __init__(self, **kwargs): super().__init__() self.trainble = False self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False - + self.is_source_constant = False + @abstractmethod def set_interchange_dim(self, interchange_dim): pass @@ -31,11 +32,21 @@ class TrainableIntervention(Intervention): def __init__(self, **kwargs): super().__init__(**kwargs) self.trainble = True - + self.is_source_constant = False + def tie_weight(self, linked_intervention): pass +class ConstantSourceIntervention(Intervention): + + """Intervention the original representations.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.is_source_constant = True + + class BasisAgnosticIntervention(Intervention): """Intervention that will modify its basis in a uncontrolled manner.""" @@ -43,6 +54,7 @@ class BasisAgnosticIntervention(Intervention): def __init__(self, **kwargs): super().__init__(**kwargs) self.basis_agnostic = True + self.is_source_constant = False class SharedWeightsTrainableIntervention(TrainableIntervention): @@ -54,7 +66,37 @@ def __init__(self, **kwargs): self.shared_weights = True -class CollectIntervention(Intervention): +class ZeroIntervention(ConstantSourceIntervention): + + """Zero-out activations.""" + + def __init__(self, embed_dim, **kwargs): + super().__init__() + self.embed_dim = embed_dim + self.interchange_dim = embed_dim + self.subspace_partition = ( + kwargs["subspace_partition"] if "subspace_partition" in kwargs else None + ) + + def set_interchange_dim(self, interchange_dim): + self.interchange_dim = interchange_dim + + def forward(self, base, source=None, subspaces=None): + return _do_intervention_by_swap( + base, + torch.zeros_like(base), + "interchange", + self.interchange_dim, + subspaces, + subspace_partition=self.subspace_partition, + use_fast=self.use_fast, + ) + + def __str__(self): + return f"ZeroIntervention(embed_dim={self.embed_dim})" + + +class CollectIntervention(ConstantSourceIntervention): """Collect activations.""" @@ -125,14 +167,19 @@ def __init__(self, embed_dim, **kwargs): self.subspace_partition = ( kwargs["subspace_partition"] if "subspace_partition" in kwargs else None ) - + self.source_representation = ( + kwargs["source_representation"] if "source_representation" in kwargs else None + ) + if self.source_representation is not None: + self.is_source_constant = True + def set_interchange_dim(self, interchange_dim): self.interchange_dim = interchange_dim def forward(self, base, source, subspaces=None): return _do_intervention_by_swap( base, - source, + source if self.source_representation is None else self.source_representation, "interchange", self.interchange_dim, subspaces, @@ -155,14 +202,19 @@ def __init__(self, embed_dim, **kwargs): self.subspace_partition = ( kwargs["subspace_partition"] if "subspace_partition" in kwargs else None ) - + self.source_representation = ( + kwargs["source_representation"] if "source_representation" in kwargs else None + ) + if self.source_representation is not None: + self.is_source_constant = True + def set_interchange_dim(self, interchange_dim): self.interchange_dim = interchange_dim def forward(self, base, source, subspaces=None): return _do_intervention_by_swap( base, - source, + source if self.source_representation is None else self.source_representation, "add", self.interchange_dim, subspaces, @@ -185,14 +237,19 @@ def __init__(self, embed_dim, **kwargs): self.subspace_partition = ( kwargs["subspace_partition"] if "subspace_partition" in kwargs else None ) - + self.source_representation = ( + kwargs["source_representation"] if "source_representation" in kwargs else None + ) + if self.source_representation is not None: + self.is_source_constant = True + def set_interchange_dim(self, interchange_dim): self.interchange_dim = interchange_dim def forward(self, base, source, subspaces=None): return _do_intervention_by_swap( base, - source, + source if self.source_representation is None else self.source_representation, "subtract", self.interchange_dim, subspaces, diff --git a/pyvene/models/modeling_utils.py b/pyvene/models/modeling_utils.py index c8337e9e..a46d47cd 100644 --- a/pyvene/models/modeling_utils.py +++ b/pyvene/models/modeling_utils.py @@ -149,10 +149,14 @@ def get_intervenable_module_hook(model, representation) -> nn.Module: ] parameter_name = type_info[0] hook_type = type_info[1] - if "%s" in parameter_name: + if "%s" in parameter_name and representation.intervenable_moe is None: # we assume it is for the layer. parameter_name = parameter_name % (representation.intervenable_layer) - + else: + parameter_name = parameter_name % ( + int(representation.intervenable_layer), + int(representation.intervenable_moe) + ) module = getattr_for_torch_module(model, parameter_name) module_hook = getattr(module, hook_type) diff --git a/tests/integration_tests/InterventionWithGPT2TestCase.py b/tests/integration_tests/InterventionWithGPT2TestCase.py index fde40c50..fde18541 100644 --- a/tests/integration_tests/InterventionWithGPT2TestCase.py +++ b/tests/integration_tests/InterventionWithGPT2TestCase.py @@ -96,7 +96,7 @@ def test_invalid_intervenable_unit_negative(self): pass else: raise ValueError("ValueError for invalid intervenable unit is not thrown") - + def _test_with_position_intervention( self, intervention_layer, @@ -373,6 +373,190 @@ def test_with_location_broadcast_vanilla_intervention_positive(self): use_boardcast=True, ) + def _test_with_position_intervention_constant_source( + self, + intervention_layer, + intervention_stream, + intervention_type, + positions=[0], + use_base_only=False, + use_fast=False, + use_boardcast=False, + ): + max_position = np.max(np.array(positions)) + if isinstance(positions[0], list): + b_s = len(positions) + else: + b_s = 10 + base = { + "input_ids": torch.randint(0, 10, (b_s, max_position + 1)).to(self.device) + } + + intervenable_config = IntervenableConfig( + intervenable_model_type=type(self.gpt2), + intervenable_representations=[ + IntervenableRepresentationConfig( + intervention_layer, + intervention_stream, + "pos", + len(positions), + source_representation=torch.rand( + self.config.n_embd).to(self.gpt2.device) \ + if "mlp_activation" != intervention_stream else \ + torch.rand(self.config.n_embd*4).to(self.gpt2.device) + ) + ], + intervenable_interventions_type=intervention_type, + ) + intervenable = IntervenableModel( + intervenable_config, self.gpt2, use_fast=use_fast + ) + intervention = list(intervenable.interventions.values())[0][0] + + base_activations = {} + _ = GPT2_RUN(self.gpt2, base["input_ids"], base_activations, {}) + _key = f"{intervention_layer}.{intervention_stream}" + + for position in positions: + base_activations[_key][:, position] = intervention( + base_activations[_key][:, position], + None, + ) + + golden_out = GPT2_RUN( + self.gpt2, base["input_ids"], {}, {_key: base_activations[_key]} + ) + + if use_base_only: + if use_boardcast: + _, out_output = intervenable( + base, + unit_locations={"base": positions[0]}, + ) + else: + _, out_output = intervenable( + base, + unit_locations={"base": ([[positions] * b_s])}, + ) + else: + if use_boardcast: + _, out_output = intervenable( + base, + unit_locations={"sources->base": (None, positions[0])}, + ) + else: + _, out_output = intervenable( + base, + unit_locations={"sources->base": (None, [[positions] * b_s])}, + ) + + self.assertTrue(torch.allclose(out_output[0], golden_out)) + + def test_with_position_intervention_constant_source_vanilla_intervention_positive(self): + """ + Enable constant source with vanilla intervention. + """ + for stream in self.nonhead_streams: + print( + f"testing constant source with stream: {stream} " + "with a single position with VanillaIntervention") + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=VanillaIntervention, + positions=[0], + ) + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=VanillaIntervention, + positions=[0], + use_base_only=True + ) + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=VanillaIntervention, + positions=[0], + use_base_only=True, + use_boardcast=True + ) + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=VanillaIntervention, + positions=[0], + use_base_only=True, + use_boardcast=True, + use_fast=True + ) + + def test_with_position_intervention_constant_source_addition_intervention_positive(self): + """ + Enable constant source with addition intervention. + """ + for stream in self.nonhead_streams: + print( + f"testing constant source with stream: {stream} " + "with a single position with AdditionIntervention") + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=AdditionIntervention, + positions=[0], + ) + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=AdditionIntervention, + positions=[0], + use_base_only=True + ) + + def test_with_position_intervention_constant_source_subtraction_intervention_positive(self): + """ + Enable constant source with subtraction intervention. + """ + for stream in self.nonhead_streams: + print( + f"testing constant source with stream: {stream} " + "with a single position with SubtractionIntervention") + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=SubtractionIntervention, + positions=[0], + ) + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=SubtractionIntervention, + positions=[0], + use_base_only=True + ) + + def test_with_position_intervention_constant_source_subtraction_intervention_positive(self): + """ + Enable constant source with subtraction intervention. + """ + for stream in self.nonhead_streams: + print( + f"testing constant source with stream: {stream} " + "with a single position with ZeroIntervention") + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=ZeroIntervention, + positions=[0], + ) + self._test_with_position_intervention_constant_source( + intervention_layer=random.randint(0, 3), + intervention_stream=stream, + intervention_type=ZeroIntervention, + positions=[0], + use_base_only=True + ) + def suite(): suite = unittest.TestSuite() suite.addTest(InterventionWithGPT2TestCase("test_clean_run_positive")) @@ -415,7 +599,12 @@ def suite(): InterventionWithGPT2TestCase( "test_with_location_broadcast_vanilla_intervention_positive" ) - ) + ) + suite.addTest( + InterventionWithGPT2TestCase( + "test_with_position_intervention_constant_source_vanilla_intervention_positive" + ) + ) return suite diff --git a/tests/utils.py b/tests/utils.py index e4b82c36..e6d82841 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -37,11 +37,7 @@ def is_package_installed(package_name): IntervenableConfig, ) from pyvene.models.intervenable_base import IntervenableModel -from pyvene.models.interventions import ( - VanillaIntervention, - RotatedSpaceIntervention, - LowRankRotatedSpaceIntervention, -) +from pyvene.models.interventions import * from pyvene.models.mlp.modelings_mlp import MLPConfig from pyvene.models.mlp.modelings_intervenable_mlp import create_mlp_classifier from pyvene.models.gpt2.modelings_intervenable_gpt2 import create_gpt2_lm From 769143a98df0390eb556f3cebbeb4a29374458db Mon Sep 17 00:00:00 2001 From: frankaging Date: Wed, 17 Jan 2024 02:35:44 -0800 Subject: [PATCH 2/3] avoid tensor squash for localist repr intervention --- pyvene/models/interventions.py | 42 +++++++++++++------ pyvene/models/modeling_utils.py | 7 +++- .../InterventionWithGPT2TestCase.py | 29 ++++++++++--- 3 files changed, 59 insertions(+), 19 deletions(-) diff --git a/pyvene/models/interventions.py b/pyvene/models/interventions.py index bce5b4f6..2ee9d8f8 100644 --- a/pyvene/models/interventions.py +++ b/pyvene/models/interventions.py @@ -40,13 +40,31 @@ def tie_weight(self, linked_intervention): class ConstantSourceIntervention(Intervention): - """Intervention the original representations.""" + """Constant source.""" def __init__(self, **kwargs): super().__init__(**kwargs) self.is_source_constant = True - + +class LocalistRepresentationIntervention(torch.nn.Module): + + """Localist representation.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.is_repr_distributed = False + + +class DistributedRepresentationIntervention(torch.nn.Module): + + """Distributed representation.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.is_repr_distributed = True + + class BasisAgnosticIntervention(Intervention): """Intervention that will modify its basis in a uncontrolled manner.""" @@ -66,7 +84,7 @@ def __init__(self, **kwargs): self.shared_weights = True -class ZeroIntervention(ConstantSourceIntervention): +class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationIntervention): """Zero-out activations.""" @@ -126,7 +144,7 @@ def __str__(self): return f"CollectIntervention(embed_dim={self.embed_dim})" -class SkipIntervention(BasisAgnosticIntervention): +class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention): """Skip the current intervening layer's computation in the hook function.""" @@ -156,7 +174,7 @@ def __str__(self): return f"SkipIntervention(embed_dim={self.embed_dim})" -class VanillaIntervention(Intervention): +class VanillaIntervention(Intervention, LocalistRepresentationIntervention): """Intervention the original representations.""" @@ -191,7 +209,7 @@ def __str__(self): return f"VanillaIntervention(embed_dim={self.embed_dim})" -class AdditionIntervention(BasisAgnosticIntervention): +class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention): """Intervention the original representations with activation addition.""" @@ -226,7 +244,7 @@ def __str__(self): return f"AdditionIntervention(embed_dim={self.embed_dim})" -class SubtractionIntervention(BasisAgnosticIntervention): +class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention): """Intervention the original representations with activation subtraction.""" @@ -261,7 +279,7 @@ def __str__(self): return f"SubtractionIntervention(embed_dim={self.embed_dim})" -class RotatedSpaceIntervention(TrainableIntervention): +class RotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention): """Intervention in the rotated space.""" @@ -299,7 +317,7 @@ def __str__(self): return f"RotatedSpaceIntervention(embed_dim={self.embed_dim})" -class BoundlessRotatedSpaceIntervention(TrainableIntervention): +class BoundlessRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention): """Intervention in the rotated space with boundary mask.""" @@ -366,7 +384,7 @@ def __str__(self): return f"BoundlessRotatedSpaceIntervention(embed_dim={self.embed_dim})" -class SigmoidMaskRotatedSpaceIntervention(TrainableIntervention): +class SigmoidMaskRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention): """Intervention in the rotated space with boundary mask.""" @@ -420,7 +438,7 @@ def __str__(self): return f"SigmoidMaskRotatedSpaceIntervention(embed_dim={self.embed_dim})" -class LowRankRotatedSpaceIntervention(TrainableIntervention): +class LowRankRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention): """Intervention in the rotated space.""" @@ -503,7 +521,7 @@ def __str__(self): return f"LowRankRotatedSpaceIntervention(embed_dim={self.embed_dim})" -class PCARotatedSpaceIntervention(BasisAgnosticIntervention): +class PCARotatedSpaceIntervention(BasisAgnosticIntervention, DistributedRepresentationIntervention): """Intervention in the pca space.""" def __init__(self, embed_dim, **kwargs): diff --git a/pyvene/models/modeling_utils.py b/pyvene/models/modeling_utils.py index a46d47cd..9b692a8f 100644 --- a/pyvene/models/modeling_utils.py +++ b/pyvene/models/modeling_utils.py @@ -2,6 +2,7 @@ from torch import nn import numpy as np from .intervenable_modelcard import * +from .interventions import * def get_internal_model_type(model): @@ -517,7 +518,8 @@ def do_intervention( # flatten original_base_shape = base_representation.shape - if len(original_base_shape) == 2: + if len(original_base_shape) == 2 or \ + isinstance(intervention, LocalistRepresentationIntervention): # no pos dimension, e.g., gru base_representation_f = base_representation source_representation_f = source_representation @@ -537,7 +539,8 @@ def do_intervention( ) # unflatten - if len(original_base_shape) == 2: + if len(original_base_shape) == 2 or \ + isinstance(intervention, LocalistRepresentationIntervention): # no pos dimension, e.g., gru pass elif len(original_base_shape) == 3: diff --git a/tests/integration_tests/InterventionWithGPT2TestCase.py b/tests/integration_tests/InterventionWithGPT2TestCase.py index fde18541..ac380e14 100644 --- a/tests/integration_tests/InterventionWithGPT2TestCase.py +++ b/tests/integration_tests/InterventionWithGPT2TestCase.py @@ -418,10 +418,14 @@ def _test_with_position_intervention_constant_source( _key = f"{intervention_layer}.{intervention_stream}" for position in positions: - base_activations[_key][:, position] = intervention( - base_activations[_key][:, position], - None, - ) + if intervention_type == ZeroIntervention: + base_activations[_key][:, position] = torch.zeros_like( + base_activations[_key][:, position]) + else: + base_activations[_key][:, position] = intervention( + base_activations[_key][:, position], + None, + ) golden_out = GPT2_RUN( self.gpt2, base["input_ids"], {}, {_key: base_activations[_key]} @@ -535,7 +539,7 @@ def test_with_position_intervention_constant_source_subtraction_intervention_pos use_base_only=True ) - def test_with_position_intervention_constant_source_subtraction_intervention_positive(self): + def test_with_position_intervention_constant_source_zero_intervention_positive(self): """ Enable constant source with subtraction intervention. """ @@ -605,6 +609,21 @@ def suite(): "test_with_position_intervention_constant_source_vanilla_intervention_positive" ) ) + suite.addTest( + InterventionWithGPT2TestCase( + "test_with_position_intervention_constant_source_addition_intervention_positive" + ) + ) + suite.addTest( + InterventionWithGPT2TestCase( + "test_with_position_intervention_constant_source_subtraction_intervention_positive" + ) + ) + suite.addTest( + InterventionWithGPT2TestCase( + "test_with_position_intervention_constant_source_zero_intervention_positive" + ) + ) return suite From bf1a13243dbf9656b17964b8563d54b9e484256e Mon Sep 17 00:00:00 2001 From: frankaging Date: Wed, 17 Jan 2024 03:23:38 -0800 Subject: [PATCH 3/3] Adding additional tests for long whole seq intervention --- .../InterventionWithGPT2TestCase.py | 82 ++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/InterventionWithGPT2TestCase.py b/tests/integration_tests/InterventionWithGPT2TestCase.py index ac380e14..98536565 100644 --- a/tests/integration_tests/InterventionWithGPT2TestCase.py +++ b/tests/integration_tests/InterventionWithGPT2TestCase.py @@ -16,7 +16,7 @@ def setUpClass(self): n_layer=4, bos_token_id=0, eos_token_id=0, - n_positions=128, + n_positions=1024, vocab_size=10, ) ) @@ -561,6 +561,81 @@ def test_with_position_intervention_constant_source_zero_intervention_positive(s use_base_only=True ) + def _test_with_long_sequence_position_intervention_constant_source_positive( + self, intervention_stream, intervention_type): + b_s = 10 + max_position = 512 + positions = [_ for _ in range(max_position)] + base = { + "input_ids": torch.randint(0, 10, (b_s, max_position)).to(self.device) + } + intervention_layer = random.randint(0, 2) + + intervenable_config = IntervenableConfig( + intervenable_model_type=type(self.gpt2), + intervenable_representations=[ + IntervenableRepresentationConfig( + intervention_layer, + intervention_stream, + "pos", + len(positions), + source_representation=torch.rand( + self.config.n_embd).to(self.gpt2.device) \ + if "mlp_activation" != intervention_stream else \ + torch.rand(self.config.n_embd*4).to(self.gpt2.device) + ) + ], + intervenable_interventions_type=intervention_type, + ) + intervenable = IntervenableModel( + intervenable_config, self.gpt2, use_fast=True + ) + intervention = list(intervenable.interventions.values())[0][0] + + base_activations = {} + _ = GPT2_RUN(self.gpt2, base["input_ids"], base_activations, {}) + _key = f"{intervention_layer}.{intervention_stream}" + + for position in positions: + if intervention_type == ZeroIntervention: + base_activations[_key] = torch.zeros_like( + base_activations[_key]) + else: + base_activations[_key] = intervention( + base_activations[_key], + None, + ) + + golden_out = GPT2_RUN( + self.gpt2, base["input_ids"], {}, {_key: base_activations[_key]} + ) + + _, out_output = intervenable( + base, + unit_locations={"base": ([[positions] * b_s])}, + ) + + self.assertTrue(torch.allclose(out_output[0], golden_out)) + + def test_with_long_sequence_position_intervention_constant_source_positive(self): + for stream in self.nonhead_streams: + print( + f"testing constant source with stream: {stream} " + "with long sequence multiple position with VanillaIntervention") + self._test_with_long_sequence_position_intervention_constant_source_positive( + intervention_stream=stream, + intervention_type=VanillaIntervention, + ) + for stream in self.nonhead_streams: + print( + f"testing constant source with stream: {stream} " + "with long sequence multiple position with ZeroIntervention") + self._test_with_long_sequence_position_intervention_constant_source_positive( + intervention_stream=stream, + intervention_type=ZeroIntervention, + ) + + def suite(): suite = unittest.TestSuite() suite.addTest(InterventionWithGPT2TestCase("test_clean_run_positive")) @@ -624,6 +699,11 @@ def suite(): "test_with_position_intervention_constant_source_zero_intervention_positive" ) ) + suite.addTest( + InterventionWithGPT2TestCase( + "_test_with_long_sequence_position_intervention_constant_source_positive" + ) + ) return suite