Merge pull request #59 from stanfordnlp/zen/constantsourceint
[P1] Adding in constant source intervention support with new tests
frankaging authored Jan 17, 2024
2 parents f46ab90 + bf1a132 commit 5894b18
Showing 7 changed files with 471 additions and 63 deletions.
5 changes: 3 additions & 2 deletions pyvene/models/configuration_intervenable_model.py
@@ -13,8 +13,9 @@
"intervenable_layer intervenable_representation_type "
"intervenable_unit max_number_of_units "
"intervenable_low_rank_dimension "
"subspace_partition group_key intervention_link_key",
defaults=(0, "block_output", "pos", 1, None, None, None, None),
"subspace_partition group_key intervention_link_key intervenable_moe "
"source_representation",
defaults=(0, "block_output", "pos", 1, None, None, None, None, None, None),
)
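For reference, the extended record above can be reproduced standalone. A minimal sketch follows; the field string and defaults are copied verbatim from this hunk, while the IntervenableRepresentationConfig class name and the 768 hidden size are assumptions not shown in the excerpt:

from collections import namedtuple

import torch

# Assumed name for the namedtuple defined in
# pyvene/models/configuration_intervenable_model.py; fields and defaults
# are copied from the hunk above.
IntervenableRepresentationConfig = namedtuple(
    "IntervenableRepresentationConfig",
    "intervenable_layer intervenable_representation_type "
    "intervenable_unit max_number_of_units "
    "intervenable_low_rank_dimension "
    "subspace_partition group_key intervention_link_key intervenable_moe "
    "source_representation",
    defaults=(0, "block_output", "pos", 1, None, None, None, None, None, None),
)

# A representation that carries its own constant source activation, so the
# resulting intervention can run without any `sources` input at forward time.
rep = IntervenableRepresentationConfig(
    intervenable_layer=0,
    intervenable_representation_type="block_output",
    source_representation=torch.zeros(768),  # assumed hidden size
)
print(rep.source_representation.shape)  # torch.Size([768])

Downstream, intervenable_base.py (next file) forwards representation.source_representation into each intervention constructor, which is what lets a forward pass run with only unit_locations={"base": ...} and no explicit sources.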


82 changes: 51 additions & 31 deletions pyvene/models/intervenable_base.py
@@ -113,8 +113,10 @@ def __init__(self, intervenable_config, model, **kwargs):
get_internal_model_type(model), model.config, representation
),
proj_dim=representation.intervenable_low_rank_dimension,
- # we can partition the subspace, and intervene on subspace
+ # additional args
  subspace_partition=representation.subspace_partition,
  use_fast=self.use_fast,
+ source_representation=representation.source_representation,
)
if representation.intervention_link_key in self._intervention_pointers:
self._intervention_reverse_link[
@@ -129,9 +131,10 @@ def __init__(self, intervenable_config, model, **kwargs):
get_internal_model_type(model), model.config, representation
),
proj_dim=representation.intervenable_low_rank_dimension,
- # we can partition the subspace, and intervene on subspace
+ # additional args
  subspace_partition=representation.subspace_partition,
  use_fast=self.use_fast,
+ source_representation=representation.source_representation,
)
# we cache the intervention for sharing if the key is not None
if representation.intervention_link_key is not None:
@@ -803,8 +806,9 @@ def hook_callback(model, args, kwargs, output=None):
if not self.is_model_stateless:
selected_output = selected_output.clone()


if isinstance(
intervention,
CollectIntervention
):
intervened_representation = do_intervention(
@@ -820,16 +824,24 @@ def hook_callback(model, args, kwargs, output=None):
# no-op to the output

else:
-    intervened_representation = do_intervention(
-        selected_output,
-        self._reconcile_stateful_cached_activations(
-            key,
-            selected_output,
-            unit_locations_base[key_i],
-        ),
-        intervention,
-        subspaces[key_i] if subspaces is not None else None,
-    )
+    if intervention.is_source_constant:
+        intervened_representation = do_intervention(
+            selected_output,
+            None,
+            intervention,
+            subspaces[key_i] if subspaces is not None else None,
+        )
+    else:
+        intervened_representation = do_intervention(
+            selected_output,
+            self._reconcile_stateful_cached_activations(
+                key,
+                selected_output,
+                unit_locations_base[key_i],
+            ),
+            intervention,
+            subspaces[key_i] if subspaces is not None else None,
+        )

# setter can produce hot activations for shared subspace interventions if linked
if key in self._intervention_reverse_link:
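The hunk above passes None as the source to do_intervention whenever intervention.is_source_constant is set, on the assumption that the intervention already carries its own source. A toy stand-in (not pyvene's actual intervention class; the name and call signature here are illustrative only):

import torch

class ToyConstantSourceIntervention:
    """Illustrative only: the stored representation replaces the live source."""

    is_source_constant = True  # the flag checked by the hook above

    def __init__(self, source_representation):
        # the "source" activation is fixed at construction time
        self.source_representation = source_representation

    def __call__(self, base, source=None, subspaces=None):
        # `source` arrives as None from the constant-source branch; broadcast
        # the stored vector over the batch/position dimensions of `base`
        return self.source_representation.expand_as(base)

iv = ToyConstantSourceIntervention(torch.zeros(4))
print(iv(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 3, 4])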
@@ -873,10 +885,10 @@ def _input_validation(
):
"""Fail fast input validation"""
if self.mode == "parallel":
- assert "sources->base" in unit_locations
+ assert "sources->base" in unit_locations or "base" in unit_locations
elif activations_sources is None and self.mode == "serial":
assert "sources->base" not in unit_locations

# sources may contain None, but length should match
if sources is not None:
if len(sources) != len(self._intervention_group):
@@ -982,10 +994,7 @@ def _wait_for_forward_with_parallel_intervention(
for intervenable_key in intervenable_keys:
# skip in case smart jump
if intervenable_key in self.activations or \
-     isinstance(
-         self.interventions[intervenable_key][0],
-         CollectIntervention
-     ):
+     self.interventions[intervenable_key][0].is_source_constant:
set_handlers = self._intervention_setter(
[intervenable_key],
[
@@ -1054,10 +1063,7 @@ def _wait_for_forward_with_serial_intervention(
for intervenable_key in intervenable_keys:
# skip in case smart jump
if intervenable_key in self.activations or \
-     isinstance(
-         self.interventions[intervenable_key][0],
-         CollectIntervention
-     ):
+     self.interventions[intervenable_key][0].is_source_constant:
# set with intervened activation to source_i+1
set_handlers = self._intervention_setter(
[intervenable_key],
@@ -1080,21 +1086,30 @@ def _broadcast_unit_locations(
batch_size,
unit_locations
):
- _unit_locations = copy.deepcopy(unit_locations)
+ _unit_locations = {}
  for k, v in unit_locations.items():
+     # special broadcast for base-only interventions
+     is_base_only = False
+     if k == "base":
+         is_base_only = True
+         k = "sources->base"
      if isinstance(v, int):
          _unit_locations[k] = ([[[v]]*batch_size], [[[v]]*batch_size])
          self.use_fast = True
-     elif isinstance(v[0], int) and isinstance(v[1], int):
+     elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int):
          _unit_locations[k] = ([[[v[0]]]*batch_size], [[[v[1]]]*batch_size])
          self.use_fast = True
-     elif isinstance(v[0], list) and isinstance(v[1], list):
-         pass # we don't support boardcase here yet.
+     elif len(v) == 2 and v[0] == None and isinstance(v[1], int):
+         _unit_locations[k] = (None, [[[v[1]]]*batch_size])
+         self.use_fast = True
+     elif len(v) == 2 and isinstance(v[0], int) and v[1] == None:
+         _unit_locations[k] = ([[[v[0]]]*batch_size], None)
+         self.use_fast = True
      else:
-         raise ValueError(
-             f"unit_locations {unit_locations} contains invalid format."
-         )
+         if is_base_only:
+             _unit_locations[k] = (None, v)
+         else:
+             _unit_locations[k] = v
return _unit_locations
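A standalone sketch of the broadcast rules added above, with two worked cases; the helper below is a renamed copy that drops the self.use_fast bookkeeping, not the library method itself:

def broadcast_unit_locations(batch_size, unit_locations):
    # mirrors the branches above: ints become per-example [[pos]] lists, and
    # "base" keys are rewritten to "sources->base" with a None source side
    out = {}
    for k, v in unit_locations.items():
        is_base_only = k == "base"
        if is_base_only:
            k = "sources->base"
        if isinstance(v, int):
            out[k] = ([[[v]] * batch_size], [[[v]] * batch_size])
        elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int):
            out[k] = ([[[v[0]]] * batch_size], [[[v[1]]] * batch_size])
        elif len(v) == 2 and v[0] is None and isinstance(v[1], int):
            out[k] = (None, [[[v[1]]] * batch_size])
        elif len(v) == 2 and isinstance(v[0], int) and v[1] is None:
            out[k] = ([[[v[0]]] * batch_size], None)
        else:
            out[k] = (None, v) if is_base_only else v
    return out

print(broadcast_unit_locations(2, {"base": 3}))
# {'sources->base': ([[[3], [3]]], [[[3], [3]]])}
print(broadcast_unit_locations(2, {"sources->base": (None, 3)}))
# {'sources->base': (None, [[[3], [3]]])}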

def forward(
@@ -1173,12 +1188,15 @@ def forward(
self._cleanup_states()

# if no source inputs, we are calling a simple forward
- if sources is None and activations_sources is None:
+ if sources is None and activations_sources is None \
+     and unit_locations is None:
return self.model(**base), None

unit_locations = self._broadcast_unit_locations(
get_batch_size(base), unit_locations)

sources = [None] if sources is None else sources

self._input_validation(
base,
sources,
Expand Down Expand Up @@ -1287,6 +1305,8 @@ def generate(
unit_locations = self._broadcast_unit_locations(
get_batch_size(base), unit_locations)

sources = [None] if sources is None else None

self._input_validation(
base,
sources,
21 changes: 21 additions & 0 deletions pyvene/models/intervention_utils.py
@@ -37,7 +37,19 @@ def __repr__(self):
def __str__(self):
return json.dumps(self.state_dict, indent=4)

def broadcast_tensor(x, target_shape):
# Ensure the last dimension of target_shape matches x's size
if target_shape[-1] != x.shape[-1]:
raise ValueError("The last dimension of target_shape must match the size of x")

# Create a shape for reshaping x that is compatible with target_shape
reshape_shape = [1] * (len(target_shape) - 1) + [x.shape[-1]]

# Reshape x and then broadcast it
x_reshaped = x.view(*reshape_shape)
broadcasted_x = x_reshaped.expand(*target_shape)
return broadcasted_x
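A quick check of the helper above; broadcast_tensor is defined at module level in pyvene/models/intervention_utils.py per this hunk, so the import below should resolve, and the tensors are arbitrary examples:

import torch

from pyvene.models.intervention_utils import broadcast_tensor

x = torch.arange(4.0)               # shape (4,)
y = broadcast_tensor(x, (2, 3, 4))  # viewed as (1, 1, 4), expanded to (2, 3, 4)
print(y.shape)                      # torch.Size([2, 3, 4])

# broadcast_tensor(x, (2, 3, 5))    # would raise ValueError: last dims differ

This broadcasting is what lets a single stored source_representation of hidden size be swapped into a (batch, positions, hidden) base slice by _do_intervention_by_swap below.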

def _do_intervention_by_swap(
base,
source,
@@ -50,6 +62,15 @@ def _do_intervention_by_swap(
"""The basic do function that guards interventions"""
if mode == "collect":
assert source is None
# auto broadcast
if base.shape != source.shape:
try:
source = broadcast_tensor(source, base.shape)
except:
raise ValueError(
f"source with shape {source.shape} cannot be broadcasted "
f"into base with shape {base.shape}."
)
# interchange
if use_fast:
if subspaces is not None: