
Commit

Merge pull request #92 from stanfordnlp/peterwz
Model utils tests for GPT-2
frankaging authored Jan 26, 2024
2 parents 8fa97fa + 639835b · commit 91de11f
Showing 2 changed files with 328 additions and 76 deletions.
pyvene/models/modeling_utils.py (150 changes: 83 additions & 67 deletions)
@@ -5,6 +5,7 @@
from .interventions import *
from .constants import *


def get_internal_model_type(model):
"""Return the model type."""
return type(model)
@@ -93,11 +94,9 @@ def getattr_for_torch_module(model, parameter_name):
return current_module


-def get_dimension_by_component(
-    model_type, model_config, component
-) -> int:
+def get_dimension_by_component(model_type, model_config, component) -> int:
"""Based on the representation, get the aligning dimension size."""

if component not in type_to_dimension_mapping[model_type]:
return None

@@ -111,15 +110,15 @@ def get_dimension_by_component(
elif "/" in proposal:
# often split by head number
if proposal.split("/")[0].isnumeric():
-                numr = int(proposal.split("/")[0])
+                numr = int(proposal.split("/")[0])
else:
numr = getattr_for_torch_module(model_config, proposal.split("/")[0])

if proposal.split("/")[1].isnumeric():
-                denr = int(proposal.split("/")[1])
+                denr = int(proposal.split("/")[1])
else:
denr = getattr_for_torch_module(model_config, proposal.split("/")[1])
-            dimension = int(numr/denr)
+            dimension = int(numr / denr)
else:
dimension = getattr_for_torch_module(model_config, proposal)
if dimension is not None:
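
The "/" branch above turns a proposal string such as "n_embd/n_head" into a per-head size. A minimal runnable sketch of that resolution logic, assuming an illustrative GPT-2-style config object rather than pyvene's real mapping tables and its getattr_for_torch_module helper:

from types import SimpleNamespace

# Illustrative stand-in for a GPT-2-style config; not pyvene's API.
config = SimpleNamespace(n_embd=768, n_head=12)

def resolve_dimension(config, proposal):
    # "n_embd/n_head" resolves to 768 / 12 = 64, the per-head dimension.
    if "/" in proposal:
        numr_key, denr_key = proposal.split("/")
        numr = int(numr_key) if numr_key.isnumeric() else getattr(config, numr_key)
        denr = int(denr_key) if denr_key.isnumeric() else getattr(config, denr_key)
        return int(numr / denr)
    return getattr(config, proposal)

assert resolve_dimension(config, "n_embd") == 768
assert resolve_dimension(config, "n_embd/n_head") == 64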
@@ -130,8 +129,10 @@

def get_module_hook(model, representation) -> nn.Module:
"""Render the intervening module with a hook."""
-    if representation.component in type_to_module_mapping[
-        get_internal_model_type(model)]:
+    if (
+        representation.component
+        in type_to_module_mapping[get_internal_model_type(model)]
+    ):
type_info = type_to_module_mapping[get_internal_model_type(model)][
representation.component
]
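
get_module_hook resolves a representation component to a concrete submodule and wraps it with a hook. A generic sketch of the underlying PyTorch mechanism, with toy names; pyvene's own plumbing (type_to_module_mapping, gather/scatter callbacks) is more involved:

import torch
import torch.nn as nn

# Toy module standing in for a resolved transformer submodule.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
captured = {}

def capture_hook(module, inputs, output):
    # A forward hook observes (and may replace) the module's output.
    captured["out"] = output.detach()

handle = model[0].register_forward_hook(capture_hook)
_ = model(torch.randn(1, 4))
handle.remove()
assert captured["out"].shape == (1, 4)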
@@ -229,18 +230,12 @@ def output_to_subcomponent(output, component, model_type, model_config):
raise ValueError(f"Unsupported {split_last_dim_by}.")
for i, (split_fn, param) in enumerate(split_last_dim_by):
if isinstance(param, str):
-                param = get_dimension_by_component(
-                    model_type, model_config, param
-                )
+                param = get_dimension_by_component(model_type, model_config, param)
subcomponent = split_fn(subcomponent, param)
return subcomponent
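
For fused projections (e.g., GPT-2's c_attn), the split_last_dim_by machinery above carves the output's last dimension into Q/K/V subcomponents. An illustrative torch-only sketch, where the head dimension of 64 is an assumption rather than a value read from a real config:

import torch

fused_qkv = torch.randn(2, 5, 3 * 64)  # b, s, 3*head_dim
query, key, value = torch.split(fused_qkv, 64, dim=-1)
assert query.shape == key.shape == value.shape == (2, 5, 64)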


-def gather_neurons(
-    tensor_input,
-    unit,
-    unit_locations_as_list
-):
+def gather_neurons(tensor_input, unit, unit_locations_as_list):
"""Gather intervening neurons.
:param tensor_input: tensors of shape (batch_size, sequence_length, ...) if
@@ -256,36 +251,37 @@
"""
if unit in {"t"}:
return tensor_input

if "." in unit:
unit_locations = (
torch.tensor(unit_locations_as_list[0], device=tensor_input.device),
torch.tensor(unit_locations_as_list[1], device=tensor_input.device),
)
# we assume unit_locations is a tuple
head_unit_locations = unit_locations[0]
pos_unit_locations = unit_locations[1]

head_tensor_output = torch.gather(
tensor_input,
1,
head_unit_locations.reshape(
*head_unit_locations.shape, *(1,) * (len(tensor_input.shape) - 2)
).expand(-1, -1, *tensor_input.shape[2:]),
) # b, h, s, d
d = head_tensor_output.shape[1]
pos_tensor_input = bhsd_to_bs_hd(head_tensor_output)
pos_tensor_output = torch.gather(
pos_tensor_input,
1,
pos_unit_locations.reshape(
*pos_unit_locations.shape, *(1,) * (len(pos_tensor_input.shape) - 2)
).expand(-1, -1, *pos_tensor_input.shape[2:]),
) # b, num_unit (pos), num_unit (h)*d
tensor_output = bs_hd_to_bhsd(pos_tensor_output, d)

return tensor_output # b, num_unit (h), num_unit (pos), d
else:
if "." in unit:
if unit in {"h.pos"}:
unit_locations = (
torch.tensor(unit_locations_as_list[0], device=tensor_input.device),
torch.tensor(unit_locations_as_list[1], device=tensor_input.device),
)
# we assume unit_locations is a tuple
head_unit_locations = unit_locations[0]
pos_unit_locations = unit_locations[1]

head_tensor_output = torch.gather(
tensor_input,
1,
head_unit_locations.reshape(
*head_unit_locations.shape, *(1,) * (len(tensor_input.shape) - 2)
).expand(-1, -1, *tensor_input.shape[2:]),
) # b, h, s, d
d = head_tensor_output.shape[1]
pos_tensor_input = bhsd_to_bs_hd(head_tensor_output)
pos_tensor_output = torch.gather(
pos_tensor_input,
1,
pos_unit_locations.reshape(
*pos_unit_locations.shape, *(1,) * (len(pos_tensor_input.shape) - 2)
).expand(-1, -1, *pos_tensor_input.shape[2:]),
) # b, num_unit (pos), num_unit (h)*d
tensor_output = bs_hd_to_bhsd(pos_tensor_output, d)

return tensor_output # b, num_unit (h), num_unit (pos), d
elif unit in {"h", "pos"}:
unit_locations = torch.tensor(
unit_locations_as_list, device=tensor_input.device
)
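
For the single-unit "h" and "pos" cases, gather_neurons reduces to one torch.gather along dimension 1, with the location tensor broadcast across the trailing dimensions. A toy sketch of that gather under assumed shapes (not pyvene's API surface):

import torch

hidden = torch.randn(2, 5, 8)               # b, s, d
positions = torch.tensor([[0, 3], [1, 4]])  # b, num_unit
gathered = torch.gather(
    hidden,
    1,
    positions.unsqueeze(-1).expand(-1, -1, hidden.shape[-1]),
)
assert gathered.shape == (2, 2, 8)          # b, num_unit, d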
@@ -343,17 +339,21 @@ def scatter_neurons(
unit_locations = torch.tensor(
unit_locations_as_list, device=tensor_input.device
)

    # if the tensor is split, we need to get the start and end indices
meta_component = output_to_subcomponent(
torch.arange(tensor_input.shape[-1]).unsqueeze(dim=0).unsqueeze(dim=0),
-        component, model_type, model_config
+        component,
+        model_type,
+        model_config,
)
-    start_index, end_index = \
-        meta_component.min().tolist(), meta_component.max().tolist()+1
+    start_index, end_index = (
+        meta_component.min().tolist(),
+        meta_component.max().tolist() + 1,
+    )
last_dim = meta_component.shape[-1]
_batch_idx = torch.arange(tensor_input.shape[0]).unsqueeze(1)

# in case it is time step, there is no sequence-related index
if unit in {"t"}:
# time series models, e.g., gru
@@ -362,39 +362,55 @@
elif unit in {"pos"}:
if use_fast:
# maybe this is all redundant, but maybe faster slightly?
-            tensor_input[_batch_idx, unit_locations[0], start_index:end_index] = replacing_tensor_input
+            tensor_input[
+                _batch_idx, unit_locations[0], start_index:end_index
+            ] = replacing_tensor_input
else:
-            tensor_input[_batch_idx, unit_locations, start_index:end_index] = replacing_tensor_input
+            tensor_input[
+                _batch_idx, unit_locations, start_index:end_index
+            ] = replacing_tensor_input
return tensor_input
elif unit in {"h", "h.pos"}:
# head-based scattering is only special for transformer-based model
# replacing_tensor_input: b_s, num_h, s, h_dim -> b_s, s, num_h*h_dim
-        old_shape = tensor_input.size() # b_s, s, -1*num_h*d
-        new_shape = tensor_input.size()[:-1] + (-1, meta_component.shape[1], last_dim) # b_s, s, -1, num_h, d
+        old_shape = tensor_input.size()  # b_s, s, -1*num_h*d
+        new_shape = tensor_input.size()[:-1] + (
+            -1,
+            meta_component.shape[1],
+            last_dim,
+        )  # b_s, s, -1, num_h, d
# get whether split by QKV
-        if component in type_to_module_mapping[model_type] and \
-            len(type_to_module_mapping[model_type][component]) > 2 and \
-            type_to_module_mapping[model_type][component][2][0] == split_three:
+        if (
+            component in type_to_module_mapping[model_type]
+            and len(type_to_module_mapping[model_type][component]) > 2
+            and type_to_module_mapping[model_type][component][2][0] == split_three
+        ):
_slice_idx = type_to_module_mapping[model_type][component][2][1]
else:
_slice_idx = 0
-        tensor_permute = tensor_input.view(new_shape) # b_s, s, -1, num_h, d
-        tensor_permute = tensor_permute.permute(0, 3, 2, 1, 4) # b_s, num_h, -1, s, d
+        tensor_permute = tensor_input.view(new_shape)  # b_s, s, -1, num_h, d
+        tensor_permute = tensor_permute.permute(0, 3, 2, 1, 4)  # b_s, num_h, -1, s, d
if "." in unit:
            # cannot do advanced indexing on two columns, thus a single for loop is unavoidable.
for i in range(unit_locations[0].shape[-1]):
-                tensor_permute[_batch_idx, unit_locations[0][:,[i]], _slice_idx, unit_locations[1]] = replacing_tensor_input[:,i]
+                tensor_permute[
+                    _batch_idx, unit_locations[0][:, [i]], _slice_idx, unit_locations[1]
+                ] = replacing_tensor_input[:, i]
else:
-            tensor_permute[_batch_idx, unit_locations, _slice_idx] = replacing_tensor_input
+            tensor_permute[
+                _batch_idx, unit_locations, _slice_idx
+            ] = replacing_tensor_input
# permute back and reshape
-        tensor_output = tensor_permute.permute(0, 3, 2, 1, 4) # b_s, s, -1, num_h, d
-        tensor_output = tensor_output.view(old_shape) # b_s, s, -1*num_h*d
+        tensor_output = tensor_permute.permute(0, 3, 2, 1, 4)  # b_s, s, -1, num_h, d
+        tensor_output = tensor_output.view(old_shape)  # b_s, s, -1*num_h*d
return tensor_output
else:
if "." in unit:
            # cannot do advanced indexing on two columns, thus a single for loop is unavoidable.
for i in range(unit_locations[0].shape[-1]):
-                tensor_input[_batch_idx, unit_locations[0][:,[i]], unit_locations[1]] = replacing_tensor_input[:,i]
+                tensor_input[
+                    _batch_idx, unit_locations[0][:, [i]], unit_locations[1]
+                ] = replacing_tensor_input[:, i]
else:
tensor_input[_batch_idx, unit_locations] = replacing_tensor_input
return tensor_input
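
The non-head branches of scatter_neurons come down to advanced indexing with a broadcast batch index, writing replacement vectors in place. A toy sketch of the "pos" write-back under assumed shapes:

import torch

hidden = torch.zeros(2, 5, 8)                            # b, s, d
replacement = torch.ones(2, 2, 8)                        # b, num_unit, d
positions = torch.tensor([[0, 3], [1, 4]])               # b, num_unit
batch_idx = torch.arange(hidden.shape[0]).unsqueeze(1)   # b, 1; broadcasts over num_unit
hidden[batch_idx, positions] = replacement               # rows (0, 3) and (1, 4) are overwritten
assert hidden[0, 3].sum() == 8 and hidden[0, 2].sum() == 0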
@@ -407,7 +423,7 @@ def do_intervention(
"""Do the actual intervention."""

num_unit = base_representation.shape[1]

# flatten
original_base_shape = base_representation.shape
if len(original_base_shape) == 2 or (
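do_intervention flattens head-indexed activations before applying the intervention, using helpers like bhsd_to_bs_hd and bs_hd_to_bhsd. A round-trip sketch of that reshape under assumed toy shapes; the exact memory layout pyvene uses is not shown in this hunk:

import torch

b, h, s, d = 2, 12, 5, 64
x = torch.randn(b, h, s, d)                        # b, h, s, d
flat = x.permute(0, 2, 1, 3).reshape(b, s, h * d)  # b, s, h*d
restored = flat.reshape(b, s, h, d).permute(0, 2, 1, 3)
assert torch.equal(x, restored)                    # lossless round trip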
