Refactor ConceptEraser into LeaceFitter and LeaceEraser; bump to v0.1.0
norabelrose committed Jul 5, 2023
1 parent 04517d4 commit 1fc1217
Showing 12 changed files with 236 additions and 209 deletions.
README.md (21 changes: 12 additions & 9 deletions)
@@ -11,18 +11,21 @@ pip install concept-erasure

# Usage

-`ConceptEraser` is the central class in this repo. It keeps track of the covariance and cross-covariance statistics needed to erase a concept, and lazily computes the LEACE parameters when needed.
+The two main classes in this repo are `LeaceFitter` and `LeaceEraser`.
+
+- `LeaceFitter` keeps track of the covariance and cross-covariance statistics needed to compute the LEACE erasure function. These statistics can be updated in an incremental fashion with `LeaceFitter.update()`. The erasure function is lazily computed when the `.eraser` property is accessed. This class uses O(_d<sup>2</sup>_) memory, where _d_ is the dimensionality of the representation, so you may want to discard it after computing the erasure function.
+- `LeaceEraser` is a compact representation of the LEACE erasure function, using only O(_dk_) memory, where _k_ is the number of classes in the concept you're trying to erase (or equivalently, the _dimensionality_ of the concept if it's not categorical).

## Batch usage

-In most cases, you probably have a batch of feature vectors `X` and concept labels `Z` and want to erase the concept from `X`. The easiest way to do this is using `ConceptEraser.fit()` followed by `ConceptEraser.forward()`:
+In most cases, you probably have a batch of feature vectors `X` and concept labels `Z` and want to erase the concept from `X`. The easiest way to do this is by using the `LeaceEraser.fit()` convenience method:

```python
import torch
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

-from concept_erasure import ConceptEraser
+from concept_erasure import LeaceEraser

n, d, k = 2048, 128, 2

@@ -40,7 +43,7 @@ real_lr = LogisticRegression(max_iter=1000).fit(X, Y)
beta = torch.from_numpy(real_lr.coef_)
assert beta.norm(p=torch.inf) > 0.1

-eraser = ConceptEraser.fit(X_t, Y_t)
+eraser = LeaceEraser.fit(X_t, Y_t)
X_ = eraser(X_t)

# But learns nothing after
@@ -50,10 +53,10 @@ assert beta.norm(p=torch.inf) < 1e-4
```

## Streaming usage
-If you have a **stream** of data, you can use `ConceptEraser.update()` to update the statistics and `ConceptEraser.forward()` to erase the concept. This is useful if you have a large dataset and want to avoid storing it all in memory.
+If you have a **stream** of data, you can use `LeaceFitter.update()` to update the statistics. This is useful if you have a large dataset and want to avoid storing it all in memory.

```python
-from concept_erasure import ConceptEraser
+from concept_erasure import LeaceFitter
from sklearn.datasets import make_classification
import torch

@@ -68,14 +71,14 @@ X, Y = make_classification(
X_t = torch.from_numpy(X)
Y_t = torch.from_numpy(Y)

-eraser = ConceptEraser(d, 1, dtype=X_t.dtype)
+fitter = LeaceFitter(d, 1, dtype=X_t.dtype)

# Compute cross-covariance matrix using batched updates
for x, y in zip(X_t.chunk(2), Y_t.chunk(2)):
-    eraser.update(x, y)
+    fitter.update(x, y)

# Erase the concept from the data
-x_ = eraser(X_t[0])
+x_ = fitter.eraser(X_t[0])
```

# Paper replication
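The new split described in the README above suggests a fit-then-discard pattern: accumulate the O(d^2) statistics in a `LeaceFitter`, then keep only the O(dk) `LeaceEraser`. A minimal sketch, with illustrative shapes and chunk count that are not part of the library API:

```python
import torch

from concept_erasure import LeaceFitter

n, d, k = 2048, 128, 2
X = torch.randn(n, d).double()
Z = torch.randint(0, k, (n,))

# Accumulate covariance statistics incrementally, as in the streaming example.
fitter = LeaceFitter(d, 1, dtype=X.dtype)
for x, z in zip(X.chunk(4), Z.chunk(4)):
    fitter.update(x, z)

# Materialize the compact eraser, then drop the O(d^2) fitter state.
eraser = fitter.eraser
del fitter

X_erased = eraser(X)
```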
concept_erasure/__init__.py (7 changes: 5 additions & 2 deletions)
@@ -1,6 +1,7 @@
-from .concept_eraser import ConceptEraser, ErasureMethod
from .concept_scrubber import ConceptScrubber
from .data import chunk_and_tokenize
+from .leace import ErasureMethod, LeaceEraser, LeaceFitter
+from .random_scrub import random_scrub
from .shrinkage import optimal_linear_shrinkage
from .utils import assert_type, chunk

@@ -9,7 +10,9 @@
"chunk",
"chunk_and_tokenize",
"optimal_linear_shrinkage",
"ConceptEraser",
"random_scrub",
"ConceptScrubber",
"LeaceEraser",
"LeaceFitter",
"ErasureMethod",
]
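For downstream code, the renamed exports above amount to a mechanical migration in the batch case. A hedged sketch (the commented-out line is the pre-0.1.0 `ConceptEraser` API removed in this commit; the data is random and illustrative):

```python
import torch

from concept_erasure import LeaceEraser

X = torch.randn(256, 64).double()
Z = torch.randint(0, 2, (256,))

# Before this commit: eraser = ConceptEraser.fit(X, Z)
eraser = LeaceEraser.fit(X, Z)
X_erased = eraser(X)
```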
concept_erasure/caching.py (new file: 41 additions)
@@ -0,0 +1,41 @@
+from functools import wraps
+from typing import Callable
+
+
+def cached_property(func: Callable) -> property:
+    """Decorator that converts a method into a lazily-evaluated cached property"""
+    # Create a secret attribute name for the cached property
+    attr_name = "_cached_" + func.__name__
+
+    @property
+    @wraps(func)
+    def _cached_property(self):
+        # If the secret attribute doesn't exist, compute the property and set it
+        if not hasattr(self, attr_name):
+            setattr(self, attr_name, func(self))
+
+        # Otherwise, return the cached property
+        return getattr(self, attr_name)
+
+    return _cached_property
+
+
+def invalidates_cache(dependent_prop_name: str) -> Callable:
+    """Invalidates a cached property when the decorated function is called"""
+    attr_name = "_cached_" + dependent_prop_name
+
+    # The actual decorator
+    def _invalidates_cache(func: Callable) -> Callable:
+        # The wrapper function
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            # Check if the secret attribute exists; if so delete it so that
+            # the cached property is recomputed
+            if hasattr(self, attr_name):
+                delattr(self, attr_name)
+
+            return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return _invalidates_cache
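A usage sketch for the two decorators above, using a toy class that is not part of the library. Presumably this is how `LeaceFitter.update()` invalidates the lazily computed `.eraser` property, though that wiring lives in `leace.py` and is not shown in this diff:

```python
from concept_erasure.caching import cached_property, invalidates_cache


class RunningSum:
    """Toy example: `total` is cached until `add` invalidates it."""

    def __init__(self) -> None:
        self.items: list[int] = []

    @cached_property
    def total(self) -> int:
        # Computed on first access, then served from `_cached_total`.
        return sum(self.items)

    @invalidates_cache("total")
    def add(self, item: int) -> None:
        # The wrapper deletes `_cached_total` (if present) before appending.
        self.items.append(item)


s = RunningSum()
s.add(3)
assert s.total == 3  # computed and cached
s.add(4)             # stale cache invalidated
assert s.total == 7  # recomputed on next access
```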
concept_erasure/concept_scrubber.py (74 changes: 10 additions & 64 deletions)
@@ -2,107 +2,53 @@
from functools import partial
from typing import Callable

-import torch
from torch import Tensor, nn
from transformers import PreTrainedModel

-from .concept_eraser import ConceptEraser, ErasureMethod
+from .leace import LeaceEraser
from .utils import assert_type, is_norm_layer, mangle_module_path


-class ConceptScrubber(nn.Module):
-    """Wrapper for a dictionary mapping module paths to `ConceptEraser` objects."""
-
-    @classmethod
-    def from_model(
-        cls,
-        model: PreTrainedModel,
-        z_dim: int = 1,
-        affine: bool = True,
-        module_suffix: str = "",
-        method: ErasureMethod = "leace",
-        pre_hook: bool = False,
-    ):
-        """Create a scrubber with a `ConceptEraser` for each norm layer in `model`."""
-        d_model = model.config.hidden_size
-
-        scrubber = cls(pre_hook=pre_hook)
-        scrubber.erasers.update(
-            {
-                mangle_module_path(name): ConceptEraser(
-                    d_model,
-                    z_dim,
-                    affine=affine,
-                    device=model.device,
-                    method=method,
-                )
-                # Note that we are unwrapping the base model here
-                for name, mod in model.base_model.named_modules()
-                if is_norm_layer(mod) and name.endswith(module_suffix)
-            }
-        )
-        return scrubber
+class ConceptScrubber:
+    """Wrapper for a dictionary mapping module paths to `LeaceEraser` objects."""

    def __init__(self, pre_hook: bool = False):
        super().__init__()

-        self.erasers = nn.ModuleDict()
+        self.erasers: dict[str, LeaceEraser] = {}
        self.pre_hook = pre_hook

    @contextmanager
    def scrub(self, model):
        """Add hooks to the model which apply the erasers during a forward pass."""

-        def scrub_hook(eraser: ConceptEraser, x: Tensor):
+        def scrub_hook(key: str, x: Tensor):
+            eraser = assert_type(LeaceEraser, self.erasers[key])
            return eraser(x).type_as(x)

        with self.apply_hook(model, scrub_hook):
            yield self

-    @contextmanager
-    def random_scrub(self, model):
-        eraser = assert_type(ConceptEraser, next(iter(self.erasers.values())))
-        d = eraser.mean_x.shape[0]
-
-        u = eraser.mean_x.new_zeros(d, eraser.z_dim)
-        u = nn.init.orthogonal_(u)
-        P = torch.eye(d, device=u.device) - u @ u.T
-
-        def scrub_hook(eraser: ConceptEraser, x: Tensor):
-            mean = eraser.mean_x
-
-            if eraser.affine:
-                _x = (x.type_as(mean) - mean) @ P.T + mean
-            else:
-                _x = x.type_as(mean) @ P.T
-
-            return _x.type_as(x)
-
-        with self.apply_hook(model, scrub_hook):
-            yield self
-
    @contextmanager
    def apply_hook(
        self,
        model: nn.Module,
-        hook_fn: Callable[[ConceptEraser, Tensor], Tensor | None],
+        hook_fn: Callable[[str, Tensor], Tensor | None],
    ):
        """Apply a `hook_fn` to each submodule in `model` that we're scrubbing."""

        def post_wrapper(_, __, output, name: str) -> Tensor | None:
            key = mangle_module_path(name)
-            eraser = assert_type(ConceptEraser, self.erasers[key])
-            return hook_fn(eraser, output)
+            return hook_fn(key, output)

        def pre_wrapper(_, inputs, name: str) -> tuple[Tensor | None, ...]:
            x, *extras = inputs
            key = mangle_module_path(name)
-            eraser = assert_type(ConceptEraser, self.erasers[key])
-            return hook_fn(eraser, x), *extras
+            return hook_fn(key, x), *extras

        # Unwrap the base model if necessary
        if isinstance(model, PreTrainedModel):
-            model = model.base_model
+            model = assert_type(PreTrainedModel, model.base_model)

        handles = [
            (
(Remainder of this diff and the other 8 changed files not shown.)
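Since `from_model` is gone, the erasers dictionary is now populated by hand with fitted `LeaceEraser` objects keyed by mangled module path, then applied with the `scrub()` context manager. A hedged end-to-end sketch: the random features standing in for captured activations, and the choice of fitting every norm layer, are illustrative assumptions modeled on the removed helper:

```python
import torch
from transformers import AutoModel, AutoTokenizer

from concept_erasure import ConceptScrubber, LeaceEraser
from concept_erasure.utils import is_norm_layer, mangle_module_path

model = AutoModel.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
d = model.config.hidden_size

# Stand-ins for hidden states captured at each norm layer, plus concept labels.
X = torch.randn(1024, d)
Z = torch.randint(0, 2, (1024,))

scrubber = ConceptScrubber()
for name, mod in model.base_model.named_modules():
    if is_norm_layer(mod):
        scrubber.erasers[mangle_module_path(name)] = LeaceEraser.fit(X, Z)

inputs = tokenizer("LEACE erases linear concepts.", return_tensors="pt")
with scrubber.scrub(model):
    outputs = model(**inputs)  # erasers applied via forward hooks
```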
