[WIP] feat: add mlp transcoders #183

Draft
wants to merge 11 commits into base: main
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -50,6 +50,10 @@ mkdocs-section-index = "^0.3.8"
mkdocstrings = "^0.24.1"
mkdocstrings-python = "^1.9.0"


[tool.poetry.group.tutorials.dependencies]
ipykernel = "^6.29.4"

[tool.poetry.extras]
mamba = ["mamba-lens"]

18 changes: 18 additions & 0 deletions sae_lens/config.py
@@ -407,6 +407,24 @@
return cls(**cfg)


@dataclass
class LanguageModelTranscoderRunnerConfig(LanguageModelSAERunnerConfig):
d_out: int = 512
hook_name_out: str = "blocks.0.hook_mlp_out"
hook_layer_out: int = 0
hook_head_index_out: Optional[int] = None

def get_base_sae_cfg_dict(self) -> dict[str, Any]:
"""Returns the config for the base Transcoder."""
return {
**super().get_base_sae_cfg_dict(),
"d_out": self.d_out,
"hook_name_out": self.hook_name_out,
"hook_layer_out": self.hook_layer_out,
"hook_head_index_out": self.hook_head_index_out,
}
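
For reference, a minimal usage sketch of this config (not part of the diff). The model, hook names, and dimensions below are illustrative assumptions; inherited fields such as model_name, hook_name, hook_layer, and d_in come from LanguageModelSAERunnerConfig.

from sae_lens.config import LanguageModelTranscoderRunnerConfig

cfg = LanguageModelTranscoderRunnerConfig(
    model_name="gpt2",                          # assumed model
    hook_name="blocks.0.ln2.hook_normalized",   # transcoder input: activations fed to the MLP
    hook_layer=0,
    d_in=768,
    hook_name_out="blocks.0.hook_mlp_out",      # transcoder output: activations the transcoder predicts
    hook_layer_out=0,
    d_out=768,
)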


@dataclass
class CacheActivationsRunnerConfig:
"""
147 changes: 99 additions & 48 deletions sae_lens/sae.py
@@ -94,6 +94,17 @@
}


@dataclass
class TranscoderConfig(SAEConfig):
# transcoder-specific forward pass details
d_out: int

# transcoder-specific dataset details
hook_name_out: str
hook_layer_out: int
hook_head_index_out: Optional[int]


class SAE(HookedRootModule):
"""
Core Sparse Autoencoder (SAE) class used for inference. For training, see `TrainingSAE`.
@@ -216,48 +227,48 @@
feature_acts = self.encode(x)
sae_out = self.decode(feature_acts)

if self.use_error_term:
with torch.no_grad():
# Recompute everything without hooks to get true error term
# Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
# This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail
# NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction.

# move x to correct dtype
x = x.to(self.dtype)

# handle hook z reshaping if needed.
sae_in = self.reshape_fn_in(x) # type: ignore

# handle run time activation normalization if needed
sae_in = self.run_time_activation_norm_fn_in(sae_in)
if not self.use_error_term:
return self.hook_sae_output(sae_out)

# apply b_dec_to_input if using that method.
sae_in_cent = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = sae_in_cent @ self.W_enc + self.b_enc
feature_acts = self.activation_fn(hidden_pre)
x_reconstruct_clean = self.reshape_fn_out(
self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec
+ self.b_dec,
d_head=self.d_head,
)

sae_out = self.run_time_activation_norm_fn_out(sae_out)
sae_error = self.hook_sae_error(x - x_reconstruct_clean)

return self.hook_sae_output(sae_out + sae_error)

return self.hook_sae_output(sae_out)
# If using error term, compute the error term and add it to the output
with torch.no_grad():
# Recompute everything without hooks to get true error term
# Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
# This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail
# NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since then the output would always equal the input and ablating features would still result in a perfect reconstruction.
feature_acts_clean = self.encode(x, apply_hooks=False)
x_reconstruct_clean = self.decode(feature_acts_clean, apply_hooks=False)
sae_error = self.hook_sae_error(x - x_reconstruct_clean)
return self.hook_sae_output(sae_out + sae_error)
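
As a sketch of the reasoning behind this refactor (not part of the diff): with use_error_term the module returns

$$\text{out} = \mathrm{SAE}_{\text{hooked}}(x) + \underbrace{\big(x - \mathrm{SAE}_{\text{clean}}(x)\big)}_{\text{detached error term}},$$

so with no interventions the hooked and clean passes agree and out equals x exactly, while an intervention on the hooked features changes only the first term. Computing the error from the clean (hook-free) pass is what keeps ablations from being silently undone by the error term.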

def encode(
self, x: Float[torch.Tensor, "... d_in"]
self, x: Float[torch.Tensor, "... d_in"], apply_hooks: bool = True
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs.
"""
sae_in = self.get_sae_in(x)
if apply_hooks:
sae_in = self.hook_sae_input(sae_in)

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = sae_in @ self.W_enc + self.b_enc
if apply_hooks:
hidden_pre = self.hook_sae_acts_pre(hidden_pre)

feature_acts = self.activation_fn(hidden_pre)
if apply_hooks:
feature_acts = self.hook_sae_acts_post(feature_acts)

return feature_acts

def get_sae_in(
self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_in_reshaped"]:
"""Get the input to the SAE.

Fixes dtype, reshapes, normalizes, and applies b_dec if necessary.
"""
# move x to correct dtype
x = x.to(self.dtype)

@@ -268,31 +279,33 @@
x = self.run_time_activation_norm_fn_in(x)

# apply b_dec_to_input if using that method.
sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))

return feature_acts
sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input)
return sae_in

def decode(
self, feature_acts: Float[torch.Tensor, "... d_sae"]
self, feature_acts: Float[torch.Tensor, "... d_sae"], apply_hooks: bool = True
) -> Float[torch.Tensor, "... d_in"]:
"""Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
# "... d_sae, d_sae d_in -> ... d_in",
sae_out = self.hook_sae_recons(
self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
)
sae_recons = self.get_sae_recons(feature_acts)
if apply_hooks:
sae_recons = self.hook_sae_recons(sae_recons)

# handle run time activation normalization if needed
# will fail if you call this twice without calling encode in between.
sae_out = self.run_time_activation_norm_fn_out(sae_out)
sae_recons = self.run_time_activation_norm_fn_out(sae_recons)

# handle hook z reshaping if needed.
sae_out = self.reshape_fn_out(sae_out, self.d_head) # type: ignore
sae_recons = self.reshape_fn_out(sae_recons, self.d_head) # type: ignore

return sae_out
return sae_recons

def get_sae_recons(
self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
return (
self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
)

@torch.no_grad()
def fold_W_dec_norm(self):
@@ -443,3 +456,41 @@
return tanh_relu
else:
raise ValueError(f"Unknown activation function: {activation_fn}")


class Transcoder(SAE):
"""A variant of sparse autoencoders that have different input and output hook points."""

cfg: TranscoderConfig # type: ignore
dtype: torch.dtype
device: torch.device

def __init__(
self,
cfg: TranscoderConfig,
use_error_term: bool = False,
):
assert isinstance(
cfg, TranscoderConfig
), f"Expected TranscoderConfig, got {cfg}"
if use_error_term:
raise NotImplementedError("Error term not yet supported for Transcoder")

super().__init__(cfg, use_error_term)

def initialize_weights_basic(self):
super().initialize_weights_basic()

# NOTE: Transcoders have an additional b_dec_out parameter.
# Reference: https://github.com/jacobdunefsky/transcoder_circuits/blob/7b44d870a5a301ef29eddfd77cb1f4dca854760a/sae_training/sparse_autoencoder.py#L93C1-L97C14
self.b_dec_out = nn.Parameter(
torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device)
)
Comment on lines +483 to +487

I don't understand why the extra bias is needed. I'm probably just confused and missing something, but it would make the implementation simpler if you don't need it.

I understand that in normal SAEs people sometimes subtract b_dec from the input. This isn't really necessary, but it has a nice interpretation: it chooses a new "zero point" that you can treat as the origin in the feature basis.

For transcoders this makes less sense. Since you aren't reconstructing the same activations, you probably don't want to tie the pre-encoder bias to the post-decoder bias.

Thus, in the current implementation we do:
$$z = \mathrm{ReLU}\big(W_{enc}(x - b_{dec}) + b_{enc}\big)$$
and
$$\mathrm{out} = W_{dec}\,z + b_{\text{dec out}}$$
This isn't any more expressive: you can always fold the first two biases ($b_{dec}$ and $b_{enc}$ above) into a single bias term. I don't see a good reason why it would result in a more interpretable zero point for the encoder basis either.

Overall I'd recommend dropping the complexity here, which maybe means you can just eliminate the Transcoder class entirely.
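
For concreteness, the folding mentioned above (a worked step added for clarity, not part of the original comment): since

$$W_{enc}(x - b_{dec}) + b_{enc} = W_{enc}x + \underbrace{\big(b_{enc} - W_{enc}b_{dec}\big)}_{=\,b'_{enc}},$$

a checkpoint trained with the subtracted pre-encoder bias can be converted to one without it by redefining $b'_{enc} = b_{enc} - W_{enc}b_{dec}$.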

Contributor Author

This makes sense! I'll try dropping the extra b_dec term when training. I was initially concerned about supporting previously-trained checkpoints, but as you say, weight folding should solve that.


def get_sae_recons(
self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_out"]:
# NOTE: b_dec_out instead of b_dec
return (
self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec
+ self.b_dec_out
)
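
To make the shape difference concrete, here is a plain-torch sketch of what the Transcoder computes (not part of the diff; the dimensions and the ReLU activation are assumptions, and the real class reads these from its TranscoderConfig):

import torch

d_in, d_sae, d_out = 768, 24576, 768
W_enc = torch.randn(d_in, d_sae) * 0.01
b_enc = torch.zeros(d_sae)
W_dec = torch.randn(d_sae, d_out) * 0.01
b_dec = torch.zeros(d_in)        # pre-encoder bias, subtracted from the input
b_dec_out = torch.zeros(d_out)   # extra post-decoder bias added in this PR

x = torch.randn(4, 128, d_in)                       # activations at hook_name (MLP input)
feature_acts = torch.relu((x - b_dec) @ W_enc + b_enc)
mlp_out_hat = feature_acts @ W_dec + b_dec_out      # predicts activations at hook_name_out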
102 changes: 99 additions & 3 deletions sae_lens/sae_training_runner.py
@@ -8,13 +8,21 @@
from safetensors.torch import save_file
from transformer_lens.hook_points import HookedRootModule

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.config import (
LanguageModelSAERunnerConfig,
LanguageModelTranscoderRunnerConfig,
)
from sae_lens.load_model import load_model
from sae_lens.sae import SAE_CFG_PATH, SAE_WEIGHTS_PATH, SPARSITY_PATH
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.geometric_median import compute_geometric_median
from sae_lens.training.sae_trainer import SAETrainer
from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig
from sae_lens.training.sae_trainer import SAETrainer, TranscoderTrainer
from sae_lens.training.training_sae import (
TrainingSAE,
TrainingSAEConfig,
TrainingTranscoder,
TrainingTranscoderConfig,
)


class InterruptedException(Exception):
@@ -208,3 +216,91 @@
wandb.log_artifact(sparsity_artifact)

return checkpoint_path


class TranscoderTrainingRunner(SAETrainingRunner):
cfg: LanguageModelTranscoderRunnerConfig # type: ignore
sae: TrainingTranscoder # type: ignore
activations_store_out: ActivationsStore

def __init__(self, cfg: LanguageModelTranscoderRunnerConfig):
assert isinstance(
cfg, LanguageModelTranscoderRunnerConfig
), "cfg must be of type LanguageModelTranscoderRunnerConfig"
self.cfg = cfg # type: ignore

self.model = load_model(
self.cfg.model_name,
device=self.cfg.device,
model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
)

self.activations_store = ActivationsStore.from_config(
self.model,
self.cfg,
)

self.activations_store_out = ActivationsStore(
model=self.model,
dataset=cfg.dataset_path,
streaming=cfg.streaming,
# NOTE: this part is different!
d_in=cfg.d_out,
hook_name=cfg.hook_name_out,
hook_layer=cfg.hook_layer_out,
hook_head_index=cfg.hook_head_index_out,
# NOTE: end different part
context_size=cfg.context_size,
n_batches_in_buffer=cfg.n_batches_in_buffer,
total_training_tokens=cfg.training_tokens,
store_batch_size_prompts=cfg.store_batch_size_prompts,
train_batch_size_tokens=cfg.train_batch_size_tokens,
prepend_bos=cfg.prepend_bos,
normalize_activations=cfg.normalize_activations,
device=torch.device(cfg.act_store_device),
dtype=cfg.dtype,
cached_activations_path=cfg.cached_activations_path,
model_kwargs=cfg.model_kwargs,
autocast_lm=cfg.autocast_lm,
)

if self.cfg.from_pretrained_path is not None:
raise NotImplementedError()
else:
self.sae = TrainingTranscoder( # type: ignore
TrainingTranscoderConfig.from_dict(
self.cfg.get_training_sae_cfg_dict(),
)
)
self._init_sae_group_b_decs()

def run(self):
"""
Run the training of the SAE.
"""

if self.cfg.log_to_wandb:
wandb.init(
project=self.cfg.wandb_project,
config=cast(Any, self.cfg),
name=self.cfg.run_name,
id=self.cfg.wandb_id,
)

trainer = TranscoderTrainer(
model=self.model,
sae=self.sae,
activation_store=self.activations_store,
activation_store_out=self.activations_store_out,
save_checkpoint_fn=self.save_checkpoint,
cfg=self.cfg,
)

self._compile_if_needed()
sae = self.run_trainer_with_interruption_handling(trainer)

if self.cfg.log_to_wandb:
wandb.finish()

Check warning on line 304 in sae_lens/sae_training_runner.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/sae_training_runner.py#L304

Added line #L304 was not covered by tests

return sae
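
For context, an end-to-end usage sketch of the new runner (not part of the diff; the model, dataset, hooks, and dimensions are illustrative assumptions, and the remaining fields fall back to the defaults inherited from LanguageModelSAERunnerConfig):

from sae_lens.config import LanguageModelTranscoderRunnerConfig
from sae_lens.sae_training_runner import TranscoderTrainingRunner

cfg = LanguageModelTranscoderRunnerConfig(
    model_name="gpt2",
    dataset_path="monology/pile-uncopyrighted",  # assumed dataset; any ActivationsStore-compatible dataset
    hook_name="blocks.0.ln2.hook_normalized",
    hook_layer=0,
    d_in=768,
    hook_name_out="blocks.0.hook_mlp_out",
    hook_layer_out=0,
    d_out=768,
    training_tokens=100_000_000,
    log_to_wandb=False,
)

transcoder = TranscoderTrainingRunner(cfg).run()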
