diff --git a/experiments/run_train_transcoder.py b/experiments/run_train_transcoder.py
new file mode 100644
index 00000000..0dc22fb0
--- /dev/null
+++ b/experiments/run_train_transcoder.py
@@ -0,0 +1,123 @@
+import os
+
+import torch
+from simple_parsing import ArgumentParser
+
+from sae_lens.config import LanguageModelTranscoderRunnerConfig
+from sae_lens.sae_training_runner import TranscoderTrainingRunner
+
+
+def setup_env_vars():
+    # Set the environment variables for the cache and the dataset.
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+def get_default_config():
+    if torch.cuda.is_available():
+        device = "cuda"
+    elif torch.backends.mps.is_available():
+        device = "mps"
+    else:
+        device = "cpu"
+
+    # total_training_steps = 20_000
+    total_training_steps = 500
+    batch_size = 4096
+    total_training_tokens = total_training_steps * batch_size
+    print(f"Total Training Tokens: {total_training_tokens}")
+
+    lr_warm_up_steps = 0
+    lr_decay_steps = 40_000
+    print(f"lr_decay_steps: {lr_decay_steps}")
+    l1_warmup_steps = 10_000
+    print(f"l1_warmup_steps: {l1_warmup_steps}")
+
+    return LanguageModelTranscoderRunnerConfig(
+        # Pick a tiny model to make this easier.
+        model_name="gelu-1l",
+        ## MLP Layer 0 ##
+        hook_name="blocks.0.ln2.hook_normalized",
+        hook_name_out="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
+        hook_layer=0,  # Only one layer in the model.
+        hook_layer_out=0,  # Only one layer in the model.
+        d_in=512,  # the width of the mlp input.
+        d_out=512,  # the width of the mlp output.
+        dataset_path="NeelNanda/c4-tokenized-2b",
+        context_size=256,
+        is_dataset_tokenized=True,
+        prepend_bos=True,  # I used to train GPT2 SAEs with a prepended BOS but no longer think we should do this.
+        # How big do we want our SAE to be?
+        expansion_factor=16,
+        # Dataset / Activation Store
+        # When we do a proper test:
+        # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100)
+        # For now:
+        training_tokens=total_training_tokens,  # For initial testing I think this is a good number.
+        train_batch_size_tokens=4096,
+        # Loss Function
+        ## Reconstruction Coefficient.
+        mse_loss_normalization=None,  # MSE loss normalization is not mentioned (so we use standard MSE loss), but note we take an average over the batch.
+        ## Anthropic does not mention using an Lp norm other than L1.
+        l1_coefficient=5,
+        lp_norm=1.0,
+        # Instead, they multiply the L1 loss contribution
+        # from each feature of the activations by the decoder norm of the corresponding feature.
+        scale_sparsity_penalty_by_decoder_norm=True,
+        # Learning Rate
+        lr_scheduler_name="constant",  # we set this independently of warmup and decay steps.
+        l1_warm_up_steps=l1_warmup_steps,
+        lr_warm_up_steps=lr_warm_up_steps,
+        lr_decay_steps=lr_decay_steps,
+        ## No ghost grad term.
+        use_ghost_grads=False,
+        # Initialization / Architecture
+        apply_b_dec_to_input=False,
+        # encoder bias zeros (I'm not sure what it is by default now);
+        # decoder bias zeros.
+        b_dec_init_method="zeros",
+        normalize_sae_decoder=False,
+        decoder_heuristic_init=True,
+        init_encoder_as_decoder_transpose=True,
+        # Optimizer
+        lr=4e-5,
+        ## The Adam optimizer has no weight decay by default, so don't worry about this.
+        adam_beta1=0.9,
+        adam_beta2=0.999,
+        # Buffer details won't matter if we cache / shuffle our activations ahead of time.
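+        # (n_batches_in_buffer is how many batches of activations are held and shuffled
+        # in memory at once; store_batch_size_prompts is how many prompts go through the
+        # model per forward pass when refilling that buffer.)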
+ n_batches_in_buffer=64, + store_batch_size_prompts=16, + normalize_activations="constant_norm_rescale", + # Feature Store + feature_sampling_window=1000, + dead_feature_window=1000, + dead_feature_threshold=1e-4, + # performance enhancement: + compile_sae=True, + # WANDB + log_to_wandb=True, # always use wandb unless you are just testing code. + wandb_project="benchmark", + wandb_log_frequency=100, + # Misc + device=device, + seed=42, + n_checkpoints=0, + checkpoint_path="checkpoints", + dtype="float32", + ) + + +def run_training(cfg: LanguageModelTranscoderRunnerConfig): + sae = TranscoderTrainingRunner(cfg).run() + assert sae is not None + # know whether or not this works by looking at the dashboard! # know whether or not this works by looking at the dashboard! + + +if __name__ == "__main__": + + parser = ArgumentParser() + parser.add_arguments( + LanguageModelTranscoderRunnerConfig, "cfg", default=get_default_config() + ) + args = parser.parse_args() + setup_env_vars() + run_training(args.cfg) diff --git a/pyproject.toml b/pyproject.toml index ce909dd8..33618008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,11 @@ mkdocs-section-index = "^0.3.8" mkdocstrings = "^0.24.1" mkdocstrings-python = "^1.9.0" + +[tool.poetry.group.tutorials.dependencies] +ipykernel = "^6.29.4" +simple-parsing = "^0.1.5" + [tool.poetry.extras] mamba = ["mamba-lens"] diff --git a/sae_lens/config.py b/sae_lens/config.py index 9893f450..4ad23d22 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -407,6 +407,24 @@ def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig": return cls(**cfg) +@dataclass +class LanguageModelTranscoderRunnerConfig(LanguageModelSAERunnerConfig): + d_out: int = 512 + hook_name_out: str = "blocks.0.hook_mlp_out" + hook_layer_out: int = 0 + hook_head_index_out: Optional[int] = None + + def get_base_sae_cfg_dict(self) -> dict[str, Any]: + """Returns the config for the base Transcoder.""" + return { + **super().get_base_sae_cfg_dict(), + "d_out": self.d_out, + "hook_name_out": self.hook_name_out, + "hook_layer_out": self.hook_layer_out, + "hook_head_index_out": self.hook_head_index_out, + } + + @dataclass class CacheActivationsRunnerConfig: """ diff --git a/sae_lens/sae.py b/sae_lens/sae.py index d5fd2ea5..96bc706b 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -94,6 +94,17 @@ def to_dict(self) -> dict[str, Any]: } +@dataclass +class TranscoderConfig(SAEConfig): + # transcoder-specific forward pass details + d_out: int + + # transcoder-specific dataset details + hook_name_out: str + hook_layer_out: int + hook_head_index_out: Optional[int] + + class SAE(HookedRootModule): """ Core Sparse Autoencoder (SAE) class used for inference. For training, see `TrainingSAE`. @@ -216,48 +227,48 @@ def forward( feature_acts = self.encode(x) sae_out = self.decode(feature_acts) - if self.use_error_term: - with torch.no_grad(): - # Recompute everything without hooks to get true error term - # Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct - # This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail - # NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction. 
- - # move x to correct dtype - x = x.to(self.dtype) - - # handle hook z reshaping if needed. - sae_in = self.reshape_fn_in(x) # type: ignore - - # handle run time activation normalization if needed - sae_in = self.run_time_activation_norm_fn_in(sae_in) + if not self.use_error_term: + return self.hook_sae_output(sae_out) - # apply b_dec_to_input if using that method. - sae_in_cent = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input) - - # "... d_in, d_in d_sae -> ... d_sae", - hidden_pre = sae_in_cent @ self.W_enc + self.b_enc - feature_acts = self.activation_fn(hidden_pre) - x_reconstruct_clean = self.reshape_fn_out( - self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec - + self.b_dec, - d_head=self.d_head, - ) - - sae_out = self.run_time_activation_norm_fn_out(sae_out) - sae_error = self.hook_sae_error(x - x_reconstruct_clean) - - return self.hook_sae_output(sae_out + sae_error) - - return self.hook_sae_output(sae_out) + # If using error term, compute the error term and add it to the output + with torch.no_grad(): + # Recompute everything without hooks to get true error term + # Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct + # This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail + # NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction. + feature_acts_clean = self.encode(x, apply_hooks=False) + x_reconstruct_clean = self.decode(feature_acts_clean, apply_hooks=False) + sae_error = self.hook_sae_error(x - x_reconstruct_clean) + return self.hook_sae_output(sae_out + sae_error) def encode( - self, x: Float[torch.Tensor, "... d_in"] + self, x: Float[torch.Tensor, "... d_in"], apply_hooks: bool = True ) -> Float[torch.Tensor, "... d_sae"]: """ Calcuate SAE features from inputs """ + sae_in = self.get_sae_in(x) + if apply_hooks: + sae_in = self.hook_sae_input(sae_in) + # "... d_in, d_in d_sae -> ... d_sae", + hidden_pre = sae_in @ self.W_enc + self.b_enc + if apply_hooks: + hidden_pre = self.hook_sae_acts_pre(hidden_pre) + + feature_acts = self.activation_fn(hidden_pre) + if apply_hooks: + feature_acts = self.hook_sae_acts_post(feature_acts) + + return feature_acts + + def get_sae_in( + self, x: Float[torch.Tensor, "... d_in"] + ) -> Float[torch.Tensor, "... d_in_reshaped"]: + """Get the input to the SAE. + + Fixes dtype, reshapes, normalizes, and applies b_dec if necessary. + """ # move x to correct dtype x = x.to(self.dtype) @@ -268,31 +279,33 @@ def encode( x = self.run_time_activation_norm_fn_in(x) # apply b_dec_to_input if using that method. - sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input)) - - # "... d_in, d_in d_sae -> ... d_sae", - hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) - feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) - - return feature_acts + sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input) + return sae_in def decode( - self, feature_acts: Float[torch.Tensor, "... d_sae"] + self, feature_acts: Float[torch.Tensor, "... d_sae"], apply_hooks: bool = True ) -> Float[torch.Tensor, "... d_in"]: """Decodes SAE feature activation tensor into a reconstructed input activation tensor.""" # "... d_sae, d_sae d_in -> ... 
d_in", - sae_out = self.hook_sae_recons( - self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec - ) + sae_recons = self.get_sae_recons(feature_acts) + if apply_hooks: + sae_recons = self.hook_sae_recons(sae_recons) # handle run time activation normalization if needed # will fail if you call this twice without calling encode in between. - sae_out = self.run_time_activation_norm_fn_out(sae_out) + sae_recons = self.run_time_activation_norm_fn_out(sae_recons) # handle hook z reshaping if needed. - sae_out = self.reshape_fn_out(sae_out, self.d_head) # type: ignore + sae_recons = self.reshape_fn_out(sae_recons, self.d_head) # type: ignore - return sae_out + return sae_recons + + def get_sae_recons( + self, feature_acts: Float[torch.Tensor, "... d_sae"] + ) -> Float[torch.Tensor, "... d_in"]: + return ( + self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec + ) @torch.no_grad() def fold_W_dec_norm(self): @@ -443,3 +456,41 @@ def tanh_relu(input: torch.Tensor) -> torch.Tensor: return tanh_relu else: raise ValueError(f"Unknown activation function: {activation_fn}") + + +class Transcoder(SAE): + """A variant of sparse autoencoders that have different input and output hook points.""" + + cfg: TranscoderConfig # type: ignore + dtype: torch.dtype + device: torch.device + + def __init__( + self, + cfg: TranscoderConfig, + use_error_term: bool = False, + ): + assert isinstance( + cfg, TranscoderConfig + ), f"Expected TranscoderConfig, got {cfg}" + if use_error_term: + raise NotImplementedError("Error term not yet supported for Transcoder") + super().__init__(cfg, use_error_term) + + def initialize_weights_basic(self): + super().initialize_weights_basic() + + # NOTE: Transcoders have an additional b_dec_out parameter. + # Reference: https://github.com/jacobdunefsky/transcoder_circuits/blob/7b44d870a5a301ef29eddfd77cb1f4dca854760a/sae_training/sparse_autoencoder.py#L93C1-L97C14 + self.b_dec_out = nn.Parameter( + torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device) + ) + + def get_sae_recons( + self, feature_acts: Float[torch.Tensor, "... d_sae"] + ) -> Float[torch.Tensor, "... 
d_out"]: + # NOTE: b_dec_out instead of b_dec + return ( + self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + + self.b_dec_out + ) diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index 278fbe92..a9268ddd 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -8,13 +8,21 @@ from safetensors.torch import save_file from transformer_lens.hook_points import HookedRootModule -from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.config import ( + LanguageModelSAERunnerConfig, + LanguageModelTranscoderRunnerConfig, +) from sae_lens.load_model import load_model from sae_lens.sae import SAE_CFG_PATH, SAE_WEIGHTS_PATH, SPARSITY_PATH from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.geometric_median import compute_geometric_median -from sae_lens.training.sae_trainer import SAETrainer -from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig +from sae_lens.training.sae_trainer import SAETrainer, TranscoderTrainer +from sae_lens.training.training_sae import ( + TrainingSAE, + TrainingSAEConfig, + TrainingTranscoder, + TrainingTranscoderConfig, +) class InterruptedException(Exception): @@ -208,3 +216,91 @@ def save_checkpoint( wandb.log_artifact(sparsity_artifact) return checkpoint_path + + +class TranscoderTrainingRunner(SAETrainingRunner): + cfg: LanguageModelTranscoderRunnerConfig # type: ignore + sae: TrainingTranscoder # type: ignore + activations_store_out: ActivationsStore + + def __init__(self, cfg: LanguageModelTranscoderRunnerConfig): + assert isinstance( + cfg, LanguageModelTranscoderRunnerConfig + ), "cfg must be of type LanguageModelTranscoderRunnerConfig" + self.cfg = cfg # type: ignore + + self.model = load_model( + self.cfg.model_class_name, + self.cfg.model_name, + device=self.cfg.device, + model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs, + ) + + self.activations_store = ActivationsStore.from_config( + self.model, + self.cfg, + ) + + self.activations_store_out = ActivationsStore( + model=self.model, + dataset=cfg.dataset_path, + streaming=cfg.streaming, + # NOTE: this part is different! + d_in=cfg.d_out, + hook_name=cfg.hook_name_out, + hook_layer=cfg.hook_layer_out, + hook_head_index=cfg.hook_head_index_out, + # NOTE: end different part + context_size=cfg.context_size, + n_batches_in_buffer=cfg.n_batches_in_buffer, + total_training_tokens=cfg.training_tokens, + store_batch_size_prompts=cfg.store_batch_size_prompts, + train_batch_size_tokens=cfg.train_batch_size_tokens, + prepend_bos=cfg.prepend_bos, + normalize_activations=cfg.normalize_activations, + device=torch.device(cfg.act_store_device), + dtype=cfg.dtype, + cached_activations_path=cfg.cached_activations_path, + model_kwargs=cfg.model_kwargs, + autocast_lm=cfg.autocast_lm, + ) + + if self.cfg.from_pretrained_path is not None: + raise NotImplementedError() + else: + self.sae = TrainingTranscoder( # type: ignore + TrainingTranscoderConfig.from_dict( + self.cfg.get_training_sae_cfg_dict(), + ) + ) + self._init_sae_group_b_decs() + + def run(self): + """ + Run the training of the SAE. 
+ """ + + if self.cfg.log_to_wandb: + wandb.init( + project=self.cfg.wandb_project, + config=cast(Any, self.cfg), + name=self.cfg.run_name, + id=self.cfg.wandb_id, + ) + + trainer = TranscoderTrainer( + model=self.model, + sae=self.sae, + activation_store=self.activations_store, + activation_store_out=self.activations_store_out, + save_checkpoint_fn=self.save_checkpoint, + cfg=self.cfg, + ) + + self._compile_if_needed() + sae = self.run_trainer_with_interruption_handling(trainer) + + if self.cfg.log_to_wandb: + wandb.finish() + + return sae diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 416a1675..97ab932e 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -9,11 +9,18 @@ from transformer_lens.hook_points import HookedRootModule from sae_lens import __version__ -from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.config import ( + LanguageModelSAERunnerConfig, + LanguageModelTranscoderRunnerConfig, +) from sae_lens.evals import run_evals from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.optim import L1Scheduler, get_lr_scheduler -from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput +from sae_lens.training.training_sae import ( + TrainingSAE, + TrainingTranscoder, + TrainStepOutput, +) # used to map between parameters which are updated during finetuning and the config str. FINETUNING_PARAMETERS = { @@ -198,7 +205,7 @@ def _train_step( ) -> TrainStepOutput: sae.train() - # Make sure the W_dec is still zero-norm + # Make sure the W_dec is still unit-norm if self.cfg.normalize_sae_decoder: sae.set_decoder_norm_to_unit_norm() @@ -397,3 +404,117 @@ def _begin_finetuning_if_needed(self): param.requires_grad = False self.finetuning = True + + +class TranscoderTrainer(SAETrainer): + def __init__( + self, + model: HookedRootModule, + sae: TrainingTranscoder, + activation_store: ActivationsStore, + activation_store_out: ActivationsStore, + save_checkpoint_fn, # type: ignore + cfg: LanguageModelTranscoderRunnerConfig, + ) -> None: + super().__init__(model, sae, activation_store, save_checkpoint_fn, cfg) + self.activation_store_out = activation_store_out + + def _train_step( # type: ignore + self, + sae: TrainingSAE, + sae_in: torch.Tensor, + sae_target: torch.Tensor, + ) -> TrainStepOutput: + # NOTE: This is the same as the SAETrainer _train_step method, but with the sae_target added. 
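+        # Put differently: a standard SAE is trained to reconstruct its own input
+        # (the target is sae_in itself), whereas the transcoder is trained to map the
+        # input-hook activations onto the output-hook activations (sae_target).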
+ + sae.train() + # Make sure the W_dec is still unit-norm + if self.cfg.normalize_sae_decoder: + sae.set_decoder_norm_to_unit_norm() + + # log and then reset the feature sparsity every feature_sampling_window steps + if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0: + if self.cfg.log_to_wandb: + sparsity_log_dict = self._build_sparsity_log_dict() + wandb.log(sparsity_log_dict, step=self.n_training_steps) + self._reset_running_sparsity_stats() + + # for documentation on autocasting see: + # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html + with self.autocast_if_enabled: + assert isinstance( + self.sae, TrainingTranscoder + ), "sae should be a TrainingTranscoder" + train_step_output = self.sae.training_forward_pass( + sae_in=sae_in, + sae_target=sae_target, + dead_neuron_mask=self.dead_neurons, + current_l1_coefficient=self.current_l1_coefficient, + ) + + with torch.no_grad(): + did_fire = (train_step_output.feature_acts > 0).float().sum(-2) > 0 + self.n_forward_passes_since_fired += 1 + self.n_forward_passes_since_fired[did_fire] = 0 + self.act_freq_scores += ( + (train_step_output.feature_acts.abs() > 0).float().sum(0) + ) + self.n_frac_active_tokens += self.cfg.train_batch_size_tokens + + # Scaler will rescale gradients if autocast is enabled + self.scaler.scale( + train_step_output.loss + ).backward() # loss.backward() if not autocasting + self.scaler.unscale_(self.optimizer) # needed to clip correctly + # TODO: Work out if grad norm clipping should be in config / how to test it. + torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0) + self.scaler.step(self.optimizer) # just ctx.optimizer.step() if not autocasting + self.scaler.update() + + if self.cfg.normalize_sae_decoder: + sae.remove_gradient_parallel_to_decoder_directions() + + self.optimizer.zero_grad() + self.lr_scheduler.step() + self.l1_scheduler.step() + + return train_step_output + + def fit(self) -> TrainingSAE: + # NOTE: This is the same as the SAETrainer fit method, but with the sae_target added. + + pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE") + + self._estimate_norm_scaling_factor_if_needed() + + # Train loop + while self.n_training_tokens < self.cfg.total_training_tokens: + # Do a training step. 
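+            # The input-hook and output-hook stores are advanced together; for the MSE
+            # target to be meaningful they must yield activations for the same tokens in
+            # the same order. next_batch() returns activations with a singleton hook
+            # dimension, which [:, 0, :] selects away.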
+ layer_acts = self.activation_store.next_batch()[:, 0, :] + layer_acts_out = self.activation_store_out.next_batch()[:, 0, :] + self.n_training_tokens += self.cfg.train_batch_size_tokens + + step_output = self._train_step( + sae=self.sae, sae_in=layer_acts, sae_target=layer_acts_out + ) + + if self.cfg.log_to_wandb: + self._log_train_step(step_output) + self._run_and_log_evals() + + self._checkpoint_if_needed() + self.n_training_steps += 1 + self._update_pbar(step_output, pbar) + + ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already) + self._begin_finetuning_if_needed() + + # save final sae group to checkpoints folder + self.save_checkpoint( + trainer=self, + checkpoint_name=f"final_{self.n_training_tokens}", + wandb_aliases=["final_model"], + ) + + pbar.close() + return self.sae diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index 2f343e86..494a7c3c 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -12,7 +12,7 @@ from torch import nn from sae_lens.config import LanguageModelSAERunnerConfig -from sae_lens.sae import SAE, SAEConfig +from sae_lens.sae import SAE, SAEConfig, Transcoder, TranscoderConfig from sae_lens.toolkit.pretrained_sae_loaders import ( load_pretrained_sae_lens_sae_components, ) @@ -82,7 +82,7 @@ def from_sae_runner_config( @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig": - return TrainingSAEConfig(**config_dict) + return cls(**config_dict) def to_dict(self) -> dict[str, Any]: return { @@ -123,6 +123,27 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: } +@dataclass +class TrainingTranscoderConfig(TrainingSAEConfig): + d_out: int = 512 + hook_name_out: str = "blocks.0.hook_mlp_out" + hook_layer_out: int = 0 + hook_head_index_out: Optional[int] = None + + @classmethod + def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingTranscoderConfig": + return cls(**config_dict) + + def get_base_sae_cfg_dict(self) -> dict[str, Any]: + return { + **super().get_base_sae_cfg_dict(), + "d_out": self.d_out, + "hook_name_out": self.hook_name_out, + "hook_layer_out": self.hook_layer_out, + "hook_head_index_out": self.hook_head_index_out, + } + + class TrainingSAE(SAE): """ A SAE used for training. This class provides a `training_forward_pass` method which calculates @@ -153,8 +174,10 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False): def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE": return cls(TrainingSAEConfig.from_dict(config_dict)) - def encode( - self, x: Float[torch.Tensor, "... d_in"] + # NOTE: The following type: ignore statement is because the parent class + # has additional kwargs for the encode() method. + def encode( # type: ignore + self, x: Float[torch.Tensor, "... d_in"], apply_hooks: bool = False ) -> Float[torch.Tensor, "... d_sae"]: """ Calcuate SAE features from inputs @@ -166,17 +189,7 @@ def encode_with_hidden_pre( self, x: Float[torch.Tensor, "... d_in"] ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]: - # move x to correct dtype - x = x.to(self.dtype) - - # handle hook z reshaping if needed. - x = self.reshape_fn_in(x) # type: ignore - - # apply b_dec_to_input if using that method. - sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input)) - - # handle run time activation normalization if needed - x = self.run_time_activation_norm_fn_in(x) + sae_in = self.get_sae_in(x) # "... 
d_in, d_in d_sae -> ... d_sae", hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) @@ -337,7 +350,9 @@ def load_from_pretrained( return sae def initialize_weights_complex(self): - """ """ + """Re-initialize the weights of the SAE.""" + # NOTE: initialize_weights_basic has been called in the parent class constructor + # so there's no need to re-initialize everything here. if self.cfg.decoder_orthogonal_init: self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T @@ -350,6 +365,7 @@ def initialize_weights_complex(self): ) self.initialize_decoder_norm_constant_norm() + # NOTE(dtch1997): This seems like it's duplicated at the end of the method. Should it be removed? elif self.cfg.normalize_sae_decoder: self.set_decoder_norm_to_unit_norm() @@ -430,3 +446,89 @@ def remove_gradient_parallel_to_decoder_directions(self): self.W_dec.data, "d_sae, d_sae d_in -> d_sae d_in", ) + + +class TrainingTranscoder(TrainingSAE, Transcoder): + """ + A transcoder used for training. This class provides a `training_forward_pass` method which calculates + losses used for training. + """ + + cfg: TrainingTranscoderConfig # type: ignore + use_error_term: bool + dtype: torch.dtype + device: torch.device + + def __init__(self, cfg: TrainingTranscoderConfig, use_error_term: bool = False): + + base_sae_cfg = TranscoderConfig.from_dict(cfg.get_base_sae_cfg_dict()) + Transcoder.__init__(self, base_sae_cfg) # type: ignore + self.cfg = cfg # type: ignore + self.use_error_term = use_error_term + + self.initialize_weights_complex() + + # The training SAE will assume that the activation store handles + # reshaping. + self.turn_off_forward_pass_hook_z_reshaping() + + self.mse_loss_fn = self._get_mse_loss_fn() + + def training_forward_pass( # type: ignore + self, + sae_in: torch.Tensor, + sae_target: torch.Tensor, + current_l1_coefficient: float, + dead_neuron_mask: Optional[torch.Tensor] = None, + ) -> TrainStepOutput: + # NOTE: This is exactly the same as the TrainingSAE class except that we use sae_target as target + + # do a forward pass to get SAE out, but we also need the + # hidden pre. + feature_acts, _ = self.encode_with_hidden_pre(sae_in) + sae_out = self.decode(feature_acts) + + # MSE LOSS + per_item_mse_loss = self.mse_loss_fn(sae_out, sae_target) + mse_loss = per_item_mse_loss.sum(dim=-1).mean() + + # GHOST GRADS + if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None: + + # first half of second forward pass + _, hidden_pre = self.encode_with_hidden_pre(sae_in) + ghost_grad_loss = self.calculate_ghost_grad_loss( + x=sae_in, + sae_out=sae_out, + per_item_mse_loss=per_item_mse_loss, + hidden_pre=hidden_pre, + dead_neuron_mask=dead_neuron_mask, + ) + else: + ghost_grad_loss = 0.0 + + # SPARSITY LOSS + # either the W_dec norms are 1 and this won't do anything or they are not 1 + # and we're using their norm in the loss function. 
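+        # With lp_norm = 1 the sparsity penalty works out to
+        #   current_l1_coefficient * sum_i |f_i(x_in)| * ||W_dec[i]||   (averaged over the batch),
+        # which is added to the MSE between decode(f(x_in)) (using b_dec_out) and x_out.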
+ weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1) + sparsity = weighted_feature_acts.norm( + p=self.cfg.lp_norm, dim=-1 + ) # sum over the feature dimension + + l1_loss = (current_l1_coefficient * sparsity).mean() + + loss = mse_loss + l1_loss + ghost_grad_loss + + return TrainStepOutput( + sae_in=sae_in, + sae_out=sae_out, + feature_acts=feature_acts, + loss=loss, + mse_loss=mse_loss.item(), + l1_loss=l1_loss.item(), + ghost_grad_loss=( + ghost_grad_loss.item() + if isinstance(ghost_grad_loss, torch.Tensor) + else ghost_grad_loss + ), + ) diff --git a/tests/benchmark/test_language_model_transcoder_runner.py b/tests/benchmark/test_language_model_transcoder_runner.py new file mode 100644 index 00000000..6a43d9c4 --- /dev/null +++ b/tests/benchmark/test_language_model_transcoder_runner.py @@ -0,0 +1,113 @@ +""" Benchmark test for transcoder training runner. + +Usage: +poetry run pytest tests/benchmark/test_language_model_transcoder_runner.py --profile-svg -s""" + +import torch + +from sae_lens.config import LanguageModelTranscoderRunnerConfig +from sae_lens.sae_training_runner import TranscoderTrainingRunner + +# os.environ["WANDB_MODE"] = "offline" # turn this off if you want to see the output + + +# The way to run this with this command: +# poetry run pytest tests/benchmark/test_language_model_sae_runner.py --profile-svg -s +def test_language_model_sae_runner(): + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + # total_training_steps = 20_000 + total_training_steps = 500 + batch_size = 4096 + total_training_tokens = total_training_steps * batch_size + print(f"Total Training Tokens: {total_training_tokens}") + + lr_warm_up_steps = 0 + lr_decay_steps = 40_000 + print(f"lr_decay_steps: {lr_decay_steps}") + l1_warmup_steps = 10_000 + print(f"l1_warmup_steps: {l1_warmup_steps}") + + cfg = LanguageModelTranscoderRunnerConfig( + # Pick a tiny model to make this easier. + model_name="gelu-1l", + ## MLP Layer 0 ## + hook_name="blocks.0.ln2.hook_normalized", + hook_name_out="blocks.0.hook_mlp_out", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points) + hook_layer=0, # Only one layer in the model. + hook_layer_out=0, # Only one layer in the model. + d_in=512, # the width of the mlp input. + d_out=512, # the width of the mlp output. + dataset_path="NeelNanda/c4-tokenized-2b", + context_size=256, + is_dataset_tokenized=True, + prepend_bos=True, # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this. + # How big do we want our SAE to be? + expansion_factor=16, + # Dataset / Activation Store + # When we do a proper test + # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100) + # For now. + training_tokens=total_training_tokens, # For initial testing I think this is a good number. + train_batch_size_tokens=4096, + # Loss Function + ## Reconstruction Coefficient. + mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch. + ## Anthropic does not mention using an Lp norm other than L1. + l1_coefficient=5, + lp_norm=1.0, + # Instead, they multiply the L1 loss contribution + # from each feature of the activations by the decoder norm of the corresponding feature. 
+        scale_sparsity_penalty_by_decoder_norm=True,
+        # Learning Rate
+        lr_scheduler_name="constant",  # we set this independently of warmup and decay steps.
+        l1_warm_up_steps=l1_warmup_steps,
+        lr_warm_up_steps=lr_warm_up_steps,
+        lr_decay_steps=lr_decay_steps,
+        ## No ghost grad term.
+        use_ghost_grads=False,
+        # Initialization / Architecture
+        apply_b_dec_to_input=False,
+        # encoder bias zeros (I'm not sure what it is by default now);
+        # decoder bias zeros.
+        b_dec_init_method="zeros",
+        normalize_sae_decoder=False,
+        decoder_heuristic_init=True,
+        init_encoder_as_decoder_transpose=True,
+        # Optimizer
+        lr=4e-5,
+        ## The Adam optimizer has no weight decay by default, so don't worry about this.
+        adam_beta1=0.9,
+        adam_beta2=0.999,
+        # Buffer details won't matter if we cache / shuffle our activations ahead of time.
+        n_batches_in_buffer=64,
+        store_batch_size_prompts=16,
+        normalize_activations="constant_norm_rescale",
+        # Feature Store
+        feature_sampling_window=1000,
+        dead_feature_window=1000,
+        dead_feature_threshold=1e-4,
+        # performance enhancement:
+        compile_sae=True,
+        # WANDB
+        log_to_wandb=True,  # always use wandb unless you are just testing code.
+        wandb_project="benchmark",
+        wandb_log_frequency=100,
+        # Misc
+        device=device,
+        seed=42,
+        n_checkpoints=0,
+        checkpoint_path="checkpoints",
+        dtype="float32",
+    )
+
+    # While this is running, keep an eye on the wandb dashboard to see how training is going.
+    sae = TranscoderTrainingRunner(cfg).run()
+
+    assert sae is not None
+    # know whether or not this works by looking at the dashboard!
diff --git a/tests/unit/analysis/test_hooked_transcoder.py b/tests/unit/analysis/test_hooked_transcoder.py
new file mode 100644
index 00000000..f1183ae2
--- /dev/null
+++ b/tests/unit/analysis/test_hooked_transcoder.py
@@ -0,0 +1,190 @@
+import einops
+import pytest
+import torch
+from transformer_lens import HookedTransformer
+from transformer_lens.hook_points import HookPoint
+
+from sae_lens import HookedSAETransformer
+from sae_lens.sae import Transcoder, TranscoderConfig
+
+MODEL = "solu-1l"
+prompt = "Hello World!"
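+# The tests below exercise the Transcoder through the same hook interface as the SAE
+# class: hook_sae_input, hook_sae_acts_pre, hook_sae_acts_post, hook_sae_recons and
+# hook_sae_output. For an MLP transcoder d_in == d_out == d_model, so output shapes
+# match input shapes even though the reconstruction target is a different hook point.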
+ + +class Counter: + def __init__(self): + self.count = 0 + + def inc(self, *args, **kwargs): # type: ignore + self.count += 1 + + +@pytest.fixture(scope="module") +def model(): + model = HookedSAETransformer.from_pretrained(MODEL, device="cpu") + yield model + model.reset_saes() + + +@pytest.fixture(scope="module") +def original_logits(model: HookedTransformer): + return model(prompt) + + +def get_hooked_mlp_transcoder(model: HookedTransformer, layer: int) -> Transcoder: + """Helper function to get a hooked MLP transcoder for a given layer of the model.""" + site_to_size = { + "ln2.hook_normalized": model.cfg.d_model, + "hook_mlp_out": model.cfg.d_model, + } + + site_in = "ln2.hook_normalized" + site_out = "hook_mlp_out" + hook_name = f"blocks.{layer}.{site_in}" + hook_name_out = f"blocks.{layer}.{site_out}" + d_in = site_to_size[site_in] + d_out = site_to_size[site_out] + + tc_cfg = TranscoderConfig( + d_in=d_in, + d_out=d_out, + d_sae=d_in * 2, + dtype="float32", + device="cpu", + model_name=MODEL, + hook_name=hook_name, + hook_name_out=hook_name_out, + hook_layer=layer, + hook_layer_out=layer, + hook_head_index=None, + hook_head_index_out=None, + activation_fn_str="relu", + prepend_bos=True, + context_size=128, + dataset_path="test", + apply_b_dec_to_input=False, + finetuning_scaling_factor=False, + sae_lens_training_version=None, + normalize_activations="none", + ) + + return Transcoder(tc_cfg) + + +@pytest.fixture( + scope="module", +) +def hooked_transcoder( + model: HookedTransformer, +) -> Transcoder: + return get_hooked_mlp_transcoder(model, 0) + + +def test_forward_reconstructs_input( + model: HookedTransformer, hooked_transcoder: Transcoder +): + """Verfiy that the Transcoder returns an output with the same shape as the input activations.""" + + # NOTE: In general, we do not expect the output of the transcoder to be equal to the input activations. + # However, for MLP transcoders specifically, the shapes do match. 
+    act_name = hooked_transcoder.cfg.hook_name
+    _, cache = model.run_with_cache(prompt, names_filter=act_name)
+    x = cache[act_name]
+
+    sae_output = hooked_transcoder(x)
+    assert sae_output.shape == x.shape
+
+
+def test_run_with_cache(model: HookedTransformer, hooked_transcoder: Transcoder):
+    """Verifies that run_with_cache caches Transcoder activations"""
+
+    act_name = hooked_transcoder.cfg.hook_name
+    _, cache = model.run_with_cache(prompt, names_filter=act_name)
+    x = cache[act_name]
+
+    sae_output, cache = hooked_transcoder.run_with_cache(x)
+    assert sae_output.shape == x.shape
+
+    assert "hook_sae_input" in cache
+    assert "hook_sae_acts_pre" in cache
+    assert "hook_sae_acts_post" in cache
+    assert "hook_sae_recons" in cache
+    assert "hook_sae_output" in cache
+
+
+def test_run_with_hooks(model: HookedTransformer, hooked_transcoder: Transcoder):
+    """Verifies that run_with_hooks works with Transcoder activations"""
+
+    c = Counter()
+    act_name = hooked_transcoder.cfg.hook_name
+
+    _, cache = model.run_with_cache(prompt, names_filter=act_name)
+    x = cache[act_name]
+
+    sae_hooks = [
+        "hook_sae_input",
+        "hook_sae_acts_pre",
+        "hook_sae_acts_post",
+        "hook_sae_recons",
+        "hook_sae_output",
+    ]
+
+    sae_output = hooked_transcoder.run_with_hooks(
+        x, fwd_hooks=[(sae_hook_name, c.inc) for sae_hook_name in sae_hooks]
+    )
+    assert sae_output.shape == x.shape
+
+    assert c.count == len(sae_hooks)
+
+
+@pytest.mark.xfail
+def test_error_term(model: HookedTransformer, hooked_transcoder: Transcoder):
+    """Verifies that if we use error_terms, HookedTranscoder returns an output that is equal to the input activations."""
+
+    act_name = hooked_transcoder.cfg.hook_name
+    hooked_transcoder.use_error_term = True
+
+    _, cache = model.run_with_cache(prompt, names_filter=act_name)
+    x = cache[act_name]
+
+    sae_output = hooked_transcoder(x)
+    assert sae_output.shape == x.shape
+    assert torch.allclose(sae_output, x, atol=1e-6)
+
+    # The remainder of this test verifies that pytorch backward computes the correct
+    # feature gradients when using error_terms. Motivated by the need to compute
+    # feature gradients for attribution patching.
+
+    act_name = hooked_transcoder.cfg.hook_name
+    hooked_transcoder.use_error_term = True
+
+    # Get input activations
+    _, cache = model.run_with_cache(prompt, names_filter=act_name)
+    x = cache[act_name]
+
+    # Cache gradients with respect to feature acts
+    hooked_transcoder.reset_hooks()
+    grad_cache = {}
+
+    def backward_cache_hook(act: torch.Tensor, hook: HookPoint):
+        grad_cache[hook.name] = act.detach()
+
+    hooked_transcoder.add_hook("hook_sae_acts_post", backward_cache_hook, "bwd")
+    hooked_transcoder.add_hook("hook_sae_output", backward_cache_hook, "bwd")
+
+    sae_output = hooked_transcoder(x)
+    assert torch.allclose(sae_output, x, atol=1e-6)
+    value = sae_output.sum()
+    value.backward()
+    hooked_transcoder.reset_hooks()
+
+    # Compute gradient analytically
+    if act_name.endswith("hook_z"):
+        reshaped_output_grad = einops.rearrange(
+            grad_cache["hook_sae_output"], "... n_heads d_head -> ...
(n_heads d_head)" + ) + analytic_grad = reshaped_output_grad @ hooked_transcoder.W_dec.T + else: + analytic_grad = grad_cache["hook_sae_output"] @ hooked_transcoder.W_dec.T + + # Compare analytic gradient with pytorch computed gradient + assert torch.allclose(grad_cache["hook_sae_acts_post"], analytic_grad, atol=1e-6) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 53143386..187375df 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -4,7 +4,7 @@ import pytest import torch -from tests.unit.helpers import TINYSTORIES_MODEL, load_model_cached +from .helpers import TINYSTORIES_MODEL, load_model_cached @pytest.fixture(autouse=True) diff --git a/tutorials/training_an_mlp_transcoder.ipynb b/tutorials/training_an_mlp_transcoder.ipynb new file mode 100644 index 00000000..cc4331ac --- /dev/null +++ b/tutorials/training_an_mlp_transcoder.ipynb @@ -0,0 +1,614 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "5O8tQblzOVHu" + }, + "source": [ + "# A very basic SAE Training Tutorial\n", + "\n", + "Please note that it is very easy for tutorial code to go stale so please have a low bar for raising an issue in the" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "shAFb9-lOVHu" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "LeRi_tw2dhae" + }, + "outputs": [], + "source": [ + "try:\n", + " import google.colab # type: ignore\n", + " from google.colab import output\n", + " %pip install sae-lens transformer-lens circuitsvis\n", + "except:\n", + " from IPython import get_ipython # type: ignore\n", + " ipython = get_ipython(); assert ipython is not None\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uy-b3CcSOVHu", + "outputId": "58ce28d0-f91f-436d-cf87-76bb26e2ecaf" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/daniel/.cache/pypoetry/virtualenvs/sae-lens-kHvvStyh-py3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n" + ] + } + ], + "source": [ + "import torch\n", + "import os\n", + "\n", + "from sae_lens.config import LanguageModelTranscoderRunnerConfig\n", + "from sae_lens.sae_training_runner import TranscoderTrainingRunner\n", + "\n", + "if torch.cuda.is_available():\n", + " device = \"cuda\"\n", + "elif torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cpu\"\n", + "\n", + "print(\"Using device:\", device)\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oe2nlqf-OVHv" + }, + "source": [ + "# Model Selection and Evaluation (Feel Free to Skip)\n", + "\n", + "We'll use the runner to train an SAE on a TinyStories Model. This is a very small model so we can train an SAE on it quite quickly. Before we get started, let's load in the model with `transformer_lens` and see what it can do.\n", + "\n", + "TransformerLens gives us 2 functions that are useful here (and circuits viz provides a third):\n", + "1. 
`transformer_lens.utils.test_prompt` will help us see when the model can infer one token.\n", + "2. `HookedTransformer.generate` will help us see what happens when we sample from the model.\n", + "3. `circuitsvis.logits.token_log_probs` will help us visualize the log probs of tokens at several positions in a prompt." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "hFz6JUMuOVHv" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/daniel/.cache/pypoetry/virtualenvs/sae-lens-kHvvStyh-py3.11/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] + } + ], + "source": [ + "from transformer_lens import HookedTransformer\n", + "\n", + "model = HookedTransformer.from_pretrained(\n", + " \"tiny-stories-1L-21M\"\n", + ") # This will wrap huggingface models and has lots of nice utilities." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aUiXrjdUOVHv" + }, + "source": [ + "### Getting a vibe for a model using `model.generate`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZZfKT5aDOVHv" + }, + "source": [ + "Let's start by generating some stories using the model." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "G4ad4Zz1OVHv" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\"Once upon a time there was a colorful rock. The powerful elephant was walking near a cave when she bumped into it. Each time, the elephant tried to push down on his trunk, but it didn't work. Then one day the elephant saw a zoo with the officers\"" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Once upon a time passes through a park, the family praises their kindness. They say they are applauding everyone and the room is cheering them on.\\nWords: applaud, alligator, comfortable\\nStory: \\n\\nOne day, there was a little girl.'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Once upon a time, there was a pretty girl named Lily. She loved to explore and discover new things. One day, she decided to go on a hike in the forest.\\n\\nAs she was walking, she met a little bird who was very kind. The'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Once upon a time, a small girl named Tim hopped on a big gray stair. He had never seen a stair before and he was so excited!\\n\\nTim hopped and hopped all around the block and gasped in delight. He was so happy and counted to ten.'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Once upon a time, there was a little girl named True. She had an emergency. She reversed things! It was her own special thing to do and she could drive the car again. She drove it in her hand all the time.\\n\\nShe would never forget'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# here we use generate to get 10 completeions with temperature 1. 
Feel free to play with the prompt to make it more interesting.\n", + "for i in range(5):\n", + " display(\n", + " model.generate(\n", + " \"Once upon a time\",\n", + " stop_at_eos=False, # avoids a bug on MPS\n", + " temperature=1,\n", + " verbose=False,\n", + " max_new_tokens=50,\n", + " )\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RDKr8o1xOVHv" + }, + "source": [ + "One thing we notice is that the model seems to be able to repeat the name of the main character very consistently. It can output a pronoun intead but in some stories will repeat the protagonists name. This seems like an interesting capability to analyse with SAEs. To better understand the models ability to remember the protagonists name, let's extract a prompt where the next character is determined and use the \"test_prompt\" utility from TransformerLens to check the ranking of the token for that name." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KsfJX-YpOVHv" + }, + "source": [ + "### Spot checking model abilities with `transformer_lens.utils.test_prompt`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "TpmPoj7uOVHv" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']\n", + "Tokenized answer: [' Lily']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+              "Rank: 1        Logit: 18.81 Prob: 13.46% Token: | Lily|\n",
+              "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m18.81\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m13.46\u001b[0m\u001b[1m% Token: | Lily|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 20.48 Prob: 71.06% Token: | she|\n", + "Top 1th token. Logit: 18.81 Prob: 13.46% Token: | Lily|\n", + "Top 2th token. Logit: 17.35 Prob: 3.11% Token: | the|\n", + "Top 3th token. Logit: 17.26 Prob: 2.86% Token: | her|\n", + "Top 4th token. Logit: 16.74 Prob: 1.70% Token: | there|\n", + "Top 5th token. Logit: 16.43 Prob: 1.25% Token: | they|\n", + "Top 6th token. Logit: 15.80 Prob: 0.66% Token: | all|\n", + "Top 7th token. Logit: 15.64 Prob: 0.56% Token: | things|\n", + "Top 8th token. Logit: 15.28 Prob: 0.39% Token: | one|\n", + "Top 9th token. Logit: 15.24 Prob: 0.38% Token: | lived|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' Lily', 1)]\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Lily'\u001b[0m, \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from transformer_lens.utils import test_prompt\n", + "\n", + "# Test the model with a prompt\n", + "test_prompt(\n", + " \"Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,\",\n", + " \" Lily\",\n", + " model,\n", + " prepend_space_to_answer=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jGzOvReDOVHv" + }, + "source": [ + "In the output above, we see that the model assigns ~ 70% probability to \"she\" being the next token, and a 13% chance to \" Lily\" being the next token. Other names like Lucy or Anna are not highly ranked." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HookedTransformer(\n", + " (embed): Embed()\n", + " (hook_embed): HookPoint()\n", + " (pos_embed): PosEmbed()\n", + " (hook_pos_embed): HookPoint()\n", + " (blocks): ModuleList(\n", + " (0): TransformerBlock(\n", + " (ln1): LayerNormPre(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (ln2): LayerNormPre(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (attn): Attention(\n", + " (hook_k): HookPoint()\n", + " (hook_q): HookPoint()\n", + " (hook_v): HookPoint()\n", + " (hook_z): HookPoint()\n", + " (hook_attn_scores): HookPoint()\n", + " (hook_pattern): HookPoint()\n", + " (hook_result): HookPoint()\n", + " )\n", + " (mlp): MLP(\n", + " (hook_pre): HookPoint()\n", + " (hook_post): HookPoint()\n", + " )\n", + " (hook_attn_in): HookPoint()\n", + " (hook_q_input): HookPoint()\n", + " (hook_k_input): HookPoint()\n", + " (hook_v_input): HookPoint()\n", + " (hook_mlp_in): HookPoint()\n", + " (hook_attn_out): HookPoint()\n", + " (hook_mlp_out): HookPoint()\n", + " (hook_resid_pre): HookPoint()\n", + " (hook_resid_mid): HookPoint()\n", + " (hook_resid_post): HookPoint()\n", + " )\n", + " )\n", + " (ln_final): LayerNormPre(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (unembed): Unembed()\n", + ")\n" + ] + } + ], + "source": [ + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "er3H1TDoOVHw" + }, + "source": [ + "# Training an SAE\n", + "\n", + "Now we're ready to train out SAE. 
We'll make a runner config, instantiate the runner and the rest is taken care of for us!\n", + "\n", + "During training, you use weights and biases to check key metrics which indicate how well we are able to optimize the variables we care about.\n", + "\n", + "To get a better sense of which variables to look at, you can read my (Joseph's) post [here](https://www.lesswrong.com/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream) and especially look at my weights and biases report [here](https://links-cdn.wandb.ai/wandb-public-images/links/jbloom/uue9i416.html).\n", + "\n", + "A few tips:\n", + "- Feel free to reorganize your wandb dashboard to put L0, CE_Loss_score, explained variance and other key metrics in one section at the top.\n", + "- Make a [run comparer](https://docs.wandb.ai/guides/app/features/panels/run-comparer) when tuning hyperparameters.\n", + "- You can download the resulting sparse autoencoder / sparsity estimate from wandb and upload them to huggingface if you want to share your SAE with other.\n", + " - cfg.json (training config)\n", + " - sae_weight.safetensors (model weights)\n", + " - sparsity.safetensors (sparsity estimate)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jCHtPycOOVHw" + }, + "source": [ + "## MLP Out\n", + "\n", + "I've tuned the hyperparameters below for a decent SAE which achieves 86% CE Loss recovered and an L0 of ~85, and runs in about 2 hours on an M3 Max. You can get an SAE that looks better faster if you only consider L0 and CE loss but it will likely have more dense features and more dead features. Here's a link to my output with two runs with two different L1's: https://wandb.ai/jbloom/sae_lens_tutorial ." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "oAsZCAdJOVHw" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", + "Total training steps: 30000\n", + "Total wandb updates: 1000\n", + "n_tokens_per_feature_sampling_window (millions): 1048.576\n", + "n_tokens_per_dead_feature_window (millions): 1048.576\n", + "We will reset the sparsity calculation 30 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n", + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mdtch1997\u001b[0m (\u001b[33msae-experiments\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.1" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/daniel/ml_workspace/SAELens/tutorials/wandb/run-20240615_190734-yc4n5b77" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/sae-experiments/sae_lens_tutorial" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/sae-experiments/sae_lens_tutorial/runs/yc4n5b77" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training SAE: 0%| | 0/122880000 [00:00>\n", + "Traceback (most recent call last):\n", + " File \"/home/daniel/.cache/pypoetry/virtualenvs/sae-lens-kHvvStyh-py3.11/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 775, in _clean_thread_parent_frames\n", + " def _clean_thread_parent_frames(\n", + "\n", + " File \"/home/daniel/ml_workspace/SAELens/sae_lens/sae_training_runner.py\", line 32, in interrupt_callback\n", + " raise InterruptedException()\n", + "sae_lens.sae_training_runner.InterruptedException: \n" + ] + } + ], + "source": [ + "total_training_steps = 30_000 # probably we should do more\n", + "batch_size = 4096\n", + "total_training_tokens = total_training_steps * batch_size\n", + "\n", + "lr_warm_up_steps = 0\n", + "lr_decay_steps = total_training_steps // 5 # 20% of training\n", + "l1_warm_up_steps = total_training_steps // 20 # 5% of training\n", + "\n", + "cfg = LanguageModelTranscoderRunnerConfig(\n", + " # Data Generating Function (Model + Training Distibuion)\n", + " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + " hook_name=\"blocks.0.ln2.hook_normalized\",\n", + " hook_name_out=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", + " hook_layer=0, # Only one layer in the model.\n", + " hook_layer_out=0, # Only one layer in the model.\n", + " d_in=1024, # the width of the mlp input.\n", + " d_out=1024, # the width of the mlp output.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", + " is_dataset_tokenized=True,\n", + " streaming=True, # we could pre-download the token dataset if it was small.\n", + " # SAE Parameters\n", + " mse_loss_normalization=None, # We won't normalize the mse loss,\n", + " expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training.\n", + " b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n", + " apply_b_dec_to_input=False, # We won't apply the decoder weights to the input.\n", + " normalize_sae_decoder=False,\n", + " scale_sparsity_penalty_by_decoder_norm=True,\n", + " decoder_heuristic_init=True,\n", + " init_encoder_as_decoder_transpose=True,\n", + " normalize_activations=\"expected_average_only_in\",\n", + " # Training Parameters\n", + " lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n", + " adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", + " adam_beta2=0.999,\n", + " lr_scheduler_name=\"constant\", # constant learning rate with warmup. 
Could be better schedules out there.\n", + " lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", + " l1_coefficient=5, # will control how sparse the feature activations are\n", + " l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", + " train_batch_size_tokens=batch_size,\n", + " context_size=256, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", + " training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", + " store_batch_size_prompts=16,\n", + " # Resampling protocol\n", + " use_ghost_grads=False, # we don't use ghost grads anymore.\n", + " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", + " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", + " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", + " # WANDB\n", + " log_to_wandb=True, # always use wandb unless you are just testing code.\n", + " wandb_project=\"sae_lens_tutorial\",\n", + " wandb_log_frequency=30,\n", + " eval_every_n_wandb_logs=20,\n", + " # Misc\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=\"float32\"\n", + ")\n", + "# look at the next cell to see some instruction for what to do while this is running.\n", + "sparse_autoencoder = TranscoderTrainingRunner(cfg).run()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}