Skip to content

Commit

Permalink
Merge pull request #102 from stanfordnlp/zen/binarymask
Browse files Browse the repository at this point in the history
[Minor] Add in trainable intervention based on binary mask intervention
  • Loading branch information
frankaging authored Feb 1, 2024
2 parents 8ad9f23 + 04692ee commit 14ff7c6
Show file tree
Hide file tree
Showing 4 changed files with 494 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyvene/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .models.interventions import DistributedRepresentationIntervention
from .models.interventions import SourcelessIntervention
from .models.interventions import NoiseIntervention
from .models.interventions import SigmoidMaskIntervention


# Utils
Expand Down
3 changes: 2 additions & 1 deletion pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ def set_temperature(self, temp: torch.Tensor):
Set temperature if needed
"""
for k, v in self.interventions.items():
if isinstance(v[0], BoundlessRotatedSpaceIntervention):
if isinstance(v[0], BoundlessRotatedSpaceIntervention) or \
isinstance(v[0], SigmoidMaskIntervention):
v[0].set_temperature(temp)

def enable_model_gradients(self):
Expand Down
35 changes: 34 additions & 1 deletion pyvene/models/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def __init__(self, **kwargs):
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
# boundary masks are initialized to close to 1
self.masks = torch.nn.Parameter(
torch.tensor([100] * self.embed_dim), requires_grad=True
torch.tensor([100.0] * self.embed_dim), requires_grad=True
)
self.temperature = torch.nn.Parameter(torch.tensor(50.0))

Expand Down Expand Up @@ -398,6 +398,39 @@ def __str__(self):
return f"SigmoidMaskRotatedSpaceIntervention()"


class SigmoidMaskIntervention(TrainableIntervention, LocalistRepresentationIntervention):

"""Intervention in the original basis with binary mask."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.mask = torch.nn.Parameter(
torch.zeros(self.embed_dim), requires_grad=True)

self.temperature = torch.nn.Parameter(torch.tensor(50.0))

def get_temperature(self):
return self.temperature

def set_temperature(self, temp: torch.Tensor):
self.temperature.data = temp

def forward(self, base, source, subspaces=None):
batch_size = base.shape[0]
# get boundary mask between 0 and 1 from sigmoid
mask_sigmoid = torch.sigmoid(self.mask / torch.tensor(self.temperature))

# interchange
intervened_output = (
1.0 - mask_sigmoid
) * base + mask_sigmoid * source

return intervened_output

def __str__(self):
return f"SigmoidMaskIntervention()"


class LowRankRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):

"""Intervention in the rotated space."""
Expand Down
Loading

0 comments on commit 14ff7c6

Please sign in to comment.