diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 13158431..3c716682 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -1,4 +1,4 @@
-## 📝 Description
+## Description
 
 [Describe Your Changes Here ...]
 
@@ -6,10 +6,10 @@
 
 [Describe Your Changes Here ...]
 
-# Checklist:
+## Checklist:
 
 - [ ] My PR title strictly follows the format: `[Your Priority] Your Title`
 - [ ] I have attached the testing log above
 - [ ] I provide enough comments to my code
 - [ ] I have changed documentations
-- [ ] I have added tests for my changes
\ No newline at end of file
+- [ ] I have added tests for my changes
diff --git a/pyvene/models/gpt2/modelings_intervenable_gpt2.py b/pyvene/models/gpt2/modelings_intervenable_gpt2.py
index b57349af..053a2b8b 100644
--- a/pyvene/models/gpt2/modelings_intervenable_gpt2.py
+++ b/pyvene/models/gpt2/modelings_intervenable_gpt2.py
@@ -22,7 +22,7 @@
     "attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK),
     "head_attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
     "attention_weight": ("h[%s].attn.attn_dropout", CONST_INPUT_HOOK),
-    "attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
+    "attention_output": ("h[%s].attn.resid_dropout", CONST_OUTPUT_HOOK),
     "attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
     "query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 0)),
     "key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 1)),
diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py
index e141b259..a946abe5 100644
--- a/pyvene/models/intervenable_base.py
+++ b/pyvene/models/intervenable_base.py
@@ -555,11 +555,11 @@ def _gather_intervention_output(
                 self._intervention_reverse_link[representations_key]
             ]
         else:
-            # cold gather
-            original_output = output
             # data structure casting
             if isinstance(output, tuple):
-                original_output = output[0]
+                original_output = output[0].clone()
+            else:
+                original_output = output.clone()
             # gather subcomponent
             original_output = output_to_subcomponent(
                 original_output,
@@ -788,7 +788,7 @@ def hook_callback(model, args, kwargs, output=None):
                         output = kwargs[list(kwargs.keys())[0]]
                     else:
                         output = args
-                
+
                 selected_output = self._gather_intervention_output(
                     output, key, unit_locations_base[key_i]
                 )
@@ -838,26 +838,15 @@ def hook_callback(model, args, kwargs, output=None):
                            self._intervention_reverse_link[key]
                        ] = intervened_representation.clone()
 
-                # very buggy due to tensor version
-                if self.model_has_grad:
-                    # TODO: figure out how to allow this!
-                    if isinstance(output, tuple):
-                        raise ValueError(
-                            "Model grad is not allowed when "
-                            "intervening output is tuple type."
-                        )
-                    output_c = output.clone()
-                    # patched in the intervned activations in-place
+                if isinstance(output, tuple):
                     _ = self._scatter_intervention_output(
-                        output_c, intervened_representation, key, unit_locations_base[key_i]
+                        output[0], intervened_representation, key, unit_locations_base[key_i]
                     )
-                    output = output_c.clone()
                 else:
-                    # patched in the intervned activations in-place
                     _ = self._scatter_intervention_output(
                         output, intervened_representation, key, unit_locations_base[key_i]
                     )
-                
+
                 self._intervention_state[key].inc_setter_version()
 
             handlers.append(module_hook(hook_callback, with_kwargs=True))
@@ -1073,7 +1062,6 @@ def _wait_for_forward_with_serial_intervention(
     def _broadcast_unit_locations(
         self,
         batch_size,
-        intervention_group_size,
         unit_locations
     ):
         if self.mode == "parallel":
@@ -1086,33 +1074,33 @@ def _broadcast_unit_locations(
                     k = "sources->base"
                 if isinstance(v, int):
                     if is_base_only:
-                        _unit_locations[k] = (None, [[[v]]*batch_size]*intervention_group_size)
+                        _unit_locations[k] = (None, [[[v]]*batch_size]*len(self.interventions))
                     else:
                         _unit_locations[k] = (
-                            [[[v]]*batch_size]*intervention_group_size,
-                            [[[v]]*batch_size]*intervention_group_size
+                            [[[v]]*batch_size]*len(self.interventions),
+                            [[[v]]*batch_size]*len(self.interventions)
                         )
                     self.use_fast = True
                 elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int):
                     _unit_locations[k] = (
-                        [[[v[0]]]*batch_size]*intervention_group_size,
-                        [[[v[1]]]*batch_size]*intervention_group_size
+                        [[[v[0]]]*batch_size]*len(self.interventions),
+                        [[[v[1]]]*batch_size]*len(self.interventions)
                     )
                     self.use_fast = True
                 elif len(v) == 2 and v[0] == None and isinstance(v[1], int):
-                    _unit_locations[k] = (None, [[[v[1]]]*batch_size]*intervention_group_size)
+                    _unit_locations[k] = (None, [[[v[1]]]*batch_size]*len(self.interventions))
                     self.use_fast = True
                 elif len(v) == 2 and isinstance(v[0], int) and v[1] == None:
-                    _unit_locations[k] = ([[[v[0]]]*batch_size]*intervention_group_size, None)
+                    _unit_locations[k] = ([[[v[0]]]*batch_size]*len(self.interventions), None)
                     self.use_fast = True
                 elif isinstance(v, list) and get_list_depth(v) == 1:
                     # [0,1,2,3] -> [[[0,1,2,3]]], ...
                     if is_base_only:
-                        _unit_locations[k] = (None, [[v]*batch_size]*intervention_group_size)
+                        _unit_locations[k] = (None, [[v]*batch_size]*len(self.interventions))
                     else:
                         _unit_locations[k] = (
-                            [[v]*batch_size]*intervention_group_size,
-                            [[v]*batch_size]*intervention_group_size
+                            [[v]*batch_size]*len(self.interventions),
+                            [[v]*batch_size]*len(self.interventions)
                         )
                     self.use_fast = True
                 else:
@@ -1125,27 +1113,27 @@ def _broadcast_unit_locations(
             for k, v in unit_locations.items():
                 if isinstance(v, int):
                     _unit_locations[k] = (
-                        [[[v]]*batch_size]*intervention_group_size,
-                        [[[v]]*batch_size]*intervention_group_size
+                        [[[v]]*batch_size]*len(self.interventions),
+                        [[[v]]*batch_size]*len(self.interventions)
                     )
                     self.use_fast = True
                 elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int):
                     _unit_locations[k] = (
-                        [[[v[0]]]*batch_size]*intervention_group_size,
-                        [[[v[1]]]*batch_size]*intervention_group_size
+                        [[[v[0]]]*batch_size]*len(self.interventions),
+                        [[[v[1]]]*batch_size]*len(self.interventions)
                     )
                     self.use_fast = True
                 elif len(v) == 2 and v[0] == None and isinstance(v[1], int):
-                    _unit_locations[k] = (None, [[[v[1]]]*batch_size]*intervention_group_size)
+                    _unit_locations[k] = (None, [[[v[1]]]*batch_size]*len(self.interventions))
                     self.use_fast = True
                 elif len(v) == 2 and isinstance(v[0], int) and v[1] == None:
-                    _unit_locations[k] = ([[[v[0]]]*batch_size]*intervention_group_size, None)
+                    _unit_locations[k] = ([[[v[0]]]*batch_size]*len(self.interventions), None)
                     self.use_fast = True
                 elif isinstance(v, list) and get_list_depth(v) == 1:
                     # [0,1,2,3] -> [[[0,1,2,3]]], ...
                     _unit_locations[k] = (
-                        [[v]*batch_size]*intervention_group_size,
-                        [[v]*batch_size]*intervention_group_size
+                        [[v]*batch_size]*len(self.interventions),
+                        [[v]*batch_size]*len(self.interventions)
                     )
                     self.use_fast = True
                 else:
@@ -1191,16 +1179,15 @@ def _broadcast_sources(
 
     def _broadcast_subspaces(
         self,
         batch_size,
-        intervention_group_size,
         subspaces
     ):
         """Broadcast simple subspaces input"""
         _subspaces = subspaces
         if isinstance(subspaces, int):
-            _subspaces = [[[subspaces]]*batch_size]*intervention_group_size
+            _subspaces = [[[subspaces]]*batch_size]*len(self.interventions)
         elif isinstance(subspaces, list) and isinstance(subspaces[0], int):
-            _subspaces = [[subspaces]*batch_size]*intervention_group_size
+            _subspaces = [[subspaces]*batch_size]*len(self.interventions)
         else:
             # TODO: subspaces is easier to add more broadcast majic.
             pass
@@ -1292,13 +1279,11 @@ def forward(
             return self.model(**base), None
 
         # broadcast
-        unit_locations = self._broadcast_unit_locations(
-            get_batch_size(base), len(self._intervention_group), unit_locations)
+        unit_locations = self._broadcast_unit_locations(get_batch_size(base), unit_locations)
         sources = [None]*len(self._intervention_group) if sources is None else sources
         sources = self._broadcast_sources(sources)
         activations_sources = self._broadcast_source_representations(activations_sources)
-        subspaces = self._broadcast_subspaces(
-            get_batch_size(base), len(self._intervention_group), subspaces)
+        subspaces = self._broadcast_subspaces(get_batch_size(base), subspaces)
 
         self._input_validation(
             base,
@@ -1412,13 +1397,11 @@ def generate(
             unit_locations = {"base": 0}
 
         # broadcast
-        unit_locations = self._broadcast_unit_locations(
-            get_batch_size(base), len(self._intervention_group), unit_locations)
+        unit_locations = self._broadcast_unit_locations(get_batch_size(base), unit_locations)
         sources = [None]*len(self._intervention_group) if sources is None else sources
         sources = self._broadcast_sources(sources)
         activations_sources = self._broadcast_source_representations(activations_sources)
-        subspaces = self._broadcast_subspaces(
-            get_batch_size(base), len(self._intervention_group), subspaces)
+        subspaces = self._broadcast_subspaces(get_batch_size(base), subspaces)
 
         self._input_validation(
             base,
diff --git a/pyvene/models/interventions.py b/pyvene/models/interventions.py
index 8cbffcfe..a63cc2a6 100644
--- a/pyvene/models/interventions.py
+++ b/pyvene/models/interventions.py
@@ -13,7 +13,7 @@ class Intervention(torch.nn.Module):
 
     def __init__(self, **kwargs):
         super().__init__()
-        self.trainble = False
+        self.trainable = False
         self.is_source_constant = False
 
         self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False
@@ -87,7 +87,7 @@ class TrainableIntervention(Intervention):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.trainble = True
+        self.trainable = True
         self.is_source_constant = False
 
     def tie_weight(self, linked_intervention):
@@ -204,7 +204,7 @@ class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None): 
         return _do_intervention_by_swap(
             base,
             source if self.source_representation is None else self.source_representation,
@@ -478,7 +478,7 @@ def __init__(self, **kwargs):
         self.pca_std = torch.nn.Parameter(
             torch.tensor(pca_std, dtype=torch.float32), requires_grad=False
         )
-        self.trainble = False
+        self.trainable = False
 
     def forward(self, base, source, subspaces=None):
         base_norm = (base - self.pca_mean) / self.pca_std
diff --git a/pyvene_101.ipynb b/pyvene_101.ipynb
index 4f9a3c24..25e7658c 100644
--- a/pyvene_101.ipynb
+++ b/pyvene_101.ipynb
@@ -68,6 +68,7 @@
     " 1. [Causal Tracing](#Composing-Complex-Intervention-Schema:-Causal-Tracing-in-15-lines)\n",
     " 1. [Inference-time Intervention](#Inference-time-Intervention)\n",
     " 1. [IntervenableModel from HuggingFace Directly](#IntervenableModel-from-HuggingFace-Directly)\n",
+    " 1. [Path Patching with DAS](#Path-Patching-with-Trainable-Interventions)\n",
     "1. [The End](#The-End)\n",
     "   "
    ]
@@ -1446,7 +1447,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 2,
    "id": "8c7dde89",
    "metadata": {},
    "outputs": [
@@ -1454,8 +1455,24 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "loaded model\n"
+      "loaded model\n",
+      "number of params: 124439808\n"
      ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[[ 0.0022, -0.1783, -0.2780,  ...,  0.0477, -0.2069,  0.1093],\n",
+       "         [ 0.0385,  0.0886, -0.6608,  ...,  0.0104, -0.4946,  0.6148],\n",
+       "         [ 0.2377, -0.2312,  0.0308,  ...,  0.1085,  0.0456,  0.2494],\n",
+       "         [-0.0034,  0.0088, -0.2219,  ...,  0.1198,  0.0759,  0.3953],\n",
+       "         [ 0.4635,  0.2698, -0.3185,  ..., -0.2946,  0.2634,  0.2714]]],\n",
+       "       grad_fn=)"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
     }
    ],
    "source": [
@@ -1465,12 +1482,24 @@
     "_, tokenizer, gpt2 = pv.create_gpt2()\n",
     "\n",
     "pv_gpt2 = pv.IntervenableModel({\n",
-    "    \"layer\": 8}, \n",
+    "    \"layer\": 8, \"component\": \"block_output\"}, \n",
     "    model=gpt2\n",
     ")\n",
     "\n",
     "pv_gpt2.enable_model_gradients()\n",
-    "# run counterfactual forward as usual"
+    "print(\"number of params:\", pv_gpt2.count_parameters())\n",
+    "\n",
+    "# run counterfactual forward as usual\n",
+    "base = tokenizer(\"The capital of Spain is\", return_tensors=\"pt\")\n",
+    "sources = [\n",
+    "    tokenizer(\"The capital of Italy is\", return_tensors=\"pt\"),\n",
+    "]\n",
+    "base_outputs, counterfactual_outputs = pv_gpt2(\n",
+    "    base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}\n",
+    ")\n",
+    "print(counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state)\n",
+    "# call backward will put gradients on model's weights\n",
+    "counterfactual_outputs.last_hidden_state.sum().backward()"
    ]
   },
   {
@@ -2145,6 +2174,100 @@
     "print(tokenizer.decode(iti_response_shared[0], skip_special_tokens=True))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "956b3b5f",
+   "metadata": {},
+   "source": [
+    "### Path Patching with Trainable Interventions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "af501960",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "loaded model\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pyvene as pv\n",
+    "\n",
+    "def path_patching_with_DAS_config(\n",
+    "    layer, last_layer, low_rank_dimension,\n",
+    "    component=\"attention_output\", unit=\"pos\"\n",
+    "):\n",
+    "    intervening_component = [{\n",
+    "        \"layer\": layer, \"component\": component, \"group_key\": 0,\n",
+    "        \"intervention_type\": pv.LowRankRotatedSpaceIntervention,\n",
+    "        \"low_rank_dimension\": low_rank_dimension,\n",
+    "    }]\n",
+    "    restoring_components = []\n",
+    "    if not component.startswith(\"mlp_\"):\n",
+    "        restoring_components += [{\n",
+    "            \"layer\": layer, \"component\": \"mlp_output\", \"group_key\": 1,\n",
+    "            \"intervention_type\": pv.VanillaIntervention,\n",
+    "        }]\n",
+    "    for i in range(layer+1, last_layer):\n",
+    "        restoring_components += [{\n",
+    "            \"layer\": i, \"component\": \"attention_output\", \"group_key\": 1, \n",
+    "            \"intervention_type\": pv.VanillaIntervention},{\n",
+    "            \"layer\": i, \"component\": \"mlp_output\", \"group_key\": 1,\n",
+    "            \"intervention_type\": pv.VanillaIntervention\n",
+    "        }]\n",
+    "    intervenable_config = pv.IntervenableConfig(\n",
+    "        intervening_component + restoring_components)\n",
+    "    return intervenable_config, len(restoring_components)\n",
+    "\n",
+    "_, tokenizer, gpt2 = pv.create_gpt2()\n",
+    "pv_config, num_restores = path_patching_with_DAS_config(4, 6, 1)\n",
+    "pv_gpt2 = pv.IntervenableModel(pv_config, model=gpt2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "be63453e",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(-0.0694, grad_fn=)"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "base = tokenizer(\"The capital of Spain is\", return_tensors=\"pt\")\n",
+    "restore_source = tokenizer(\"The capital of Spain is\", return_tensors=\"pt\")\n",
+    "source = tokenizer(\"The capital of Italy is\", return_tensors=\"pt\")\n",
+    "\n",
+    "# zero-out grads\n",
+    "_ = pv_gpt2.model.eval()\n",
+    "for k, v in pv_gpt2.interventions.items():\n",
+    "    v[0].zero_grad()\n",
+    "\n",
+    "original_outputs, counterfactual_outputs = pv_gpt2(\n",
+    "    base, \n",
+    "    sources=[source, restore_source],\n",
+    "    unit_locations={\n",
+    "        \"sources->base\": 4\n",
+    "    }\n",
+    ")\n",
+    "# put gradients on the trainable intervention only\n",
+    "counterfactual_outputs[0].sum().backward()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "bc6eb49d",
diff --git a/tests/integration_tests/InterventionWithMLPTestCase.py b/tests/integration_tests/InterventionWithMLPTestCase.py
index db8515f0..1623a9b6 100644
--- a/tests/integration_tests/InterventionWithMLPTestCase.py
+++ b/tests/integration_tests/InterventionWithMLPTestCase.py
@@ -291,13 +291,8 @@ def test_no_intervention_link_negative(self):
             {"sources->base": ([[[0]] * b_s, [[0]] * b_s], [[[0]] * b_s, [[0]] * b_s])},
             subspaces=[[[0]] * b_s, [[1]] * b_s],
         )
-
-        try:
-            our_out_overwrite[0].sum().backward()
-        except RuntimeError:
-            pass
-        else:
-            raise AssertionError("RuntimeError by torch was not raised")
+        # it will work but the gradient is not accurate
+        our_out_overwrite[0].sum().backward()
 
 
 def suite():
diff --git a/tutorials/advanced_tutorials/tutorial_ioi_utils.py b/tutorials/advanced_tutorials/tutorial_ioi_utils.py
index 8478a1a1..ed1f28c5 100644
--- a/tutorials/advanced_tutorials/tutorial_ioi_utils.py
+++ b/tutorials/advanced_tutorials/tutorial_ioi_utils.py
@@ -881,3 +881,212 @@ def find_variable_at(
     if return_intervenable:
         return data, intervenable
     return data
+
+
+def path_patching_config(
+    layer, last_layer, low_rank_dimension,
+    component="attention_output", unit="pos"
+):
+    intervening_component = [{
+        "layer": layer, "component": component,
+        "unit": unit, "group_key": 0,
+        "intervention_type": LowRankRotatedSpaceIntervention,
+        "low_rank_dimension": low_rank_dimension,
+    }]
+    restoring_components = []
+    if not component.startswith("mlp_"):
+        restoring_components += [{
+            "layer": layer, "component": "mlp_output", "group_key": 1,
+            "intervention_type": VanillaIntervention,
+        }]
+    for i in range(layer+1, last_layer):
+        restoring_components += [{
+            "layer": i, "component": "attention_output", "group_key": 1, 
+            "intervention_type": VanillaIntervention},{
+            "layer": i, "component": "mlp_output", "group_key": 1,
+            "intervention_type": VanillaIntervention
+        }]
+    intervenable_config = IntervenableConfig(
+        intervening_component + restoring_components)
+    return intervenable_config, len(restoring_components)
+
+
+def with_path_patch_find_variable_at(
+    gpt2,
+    tokenizer,
+    positions,
+    layers,
+    stream,
+    low_rank_dimension=1,
+    seed=42,
+    debug=False,
+):
+    transformers.set_seed(seed)
+
+    train_distribution = PromptDistribution(
+        names=NAMES[: len(NAMES) // 2],
+        objects=OBJECTS[: len(OBJECTS) // 2],
+        places=PLACES[: len(PLACES) // 2],
+        prefix_len=2,
+        prefixes=PREFIXES,
+        templates=TEMPLATES[:2],
+    )
+
+    test_distribution = PromptDistribution(
+        names=NAMES[len(NAMES) // 2 :],
+        objects=OBJECTS[len(OBJECTS) // 2 :],
+        places=PLACES[len(PLACES) // 2 :],
+        prefix_len=2,
+        prefixes=PREFIXES,
+        templates=TEMPLATES[2:],
+    )
+
+    D_train = train_distribution.sample_das(
+        tokenizer=tokenizer,
+        base_patterns=["ABB", "BAB"],
+        source_patterns=["ABB", "BAB"],
+        labels="position",
+        samples_per_combination=50,
+    )
+    D_test = test_distribution.sample_das(
+        tokenizer=tokenizer,
+        base_patterns=[
+            "ABB",
+        ],
+        source_patterns=["BAB"],
+        labels="position",
+        samples_per_combination=50,
+    ) + test_distribution.sample_das(
+        tokenizer=tokenizer,
+        base_patterns=[
+            "BAB",
+        ],
+        source_patterns=["ABB"],
+        labels="position",
+        samples_per_combination=50,
+    )
+
+    data = []
+
+    batch_size = 20
+    eval_every = 5
+    initial_lr = 0.01
+    n_epochs = 10
+    aligning_stream = stream
+
+    for aligning_pos in positions:
+        for aligning_layer in layers:
+            if debug:
+                print(
+                    f"finding name position at: pos->{aligning_pos}, "
+                    f"layers->{aligning_layer}, stream->{stream}"
+                )
+            config, num_restores = path_patching_config(
+                aligning_layer,
+                gpt2.config.n_layer,
+                low_rank_dimension,
+                component=aligning_stream,
+                unit="pos"
+            )
+            intervenable = IntervenableModel(config, gpt2)
+            intervenable.set_device("cuda")
+            intervenable.disable_model_gradients()
+            total_step = 0
+            optimizer = torch.optim.Adam(
+                intervenable.get_trainable_parameters(), lr=initial_lr
+            )
+            scheduler = torch.optim.lr_scheduler.LinearLR(
+                optimizer, end_factor=0.1, total_iters=n_epochs
+            )
+
+            for epoch in range(n_epochs):
+                torch.cuda.empty_cache()
+                for batch_dataset in D_train.batches(batch_size=batch_size):
+                    # prepare base
+                    base_inputs = batch_dataset.base.tokens
+                    b_s = base_inputs["input_ids"].shape[0]
+                    for k, v in base_inputs.items():
+                        if v is not None and isinstance(v, torch.Tensor):
+                            base_inputs[k] = v.to(gpt2.device)
+                    # prepare source
+                    source_inputs = batch_dataset.source.tokens
+                    for k, v in source_inputs.items():
+                        if v is not None and isinstance(v, torch.Tensor):
+                            source_inputs[k] = v.to(gpt2.device)
+                    # prepare label
+                    labels = batch_dataset.patched_answer_tokens[:, 0].to(
+                        gpt2.device
+                    )
+
+                    assert all(x == 18 for x in batch_dataset.base.lengths)
+                    assert all(x == 18 for x in batch_dataset.source.lengths)
+
+                    _, counterfactual_outputs = intervenable(
+                        {"input_ids": base_inputs["input_ids"]},
+                        [{"input_ids": source_inputs["input_ids"]}, {"input_ids": base_inputs["input_ids"]}],
+                        {
+                            "sources->base": (
+                                [[[aligning_pos]] * b_s]+[[[aligning_pos]] * b_s]*num_restores,
+                                [[[aligning_pos]] * b_s]+[[[aligning_pos]] * b_s]*num_restores,
+                            )
+                        },
+                    )
+
+                    eval_metrics = compute_metrics(
+                        [counterfactual_outputs.logits], [labels]
+                    )
+                    loss = calculate_loss(counterfactual_outputs.logits, labels)
+                    loss_str = round(loss.item(), 2)
+                    loss.backward()
+                    optimizer.step()
+                    scheduler.step()
+                    intervenable.set_zero_grad()
+                    total_step += 1
+
+            # eval
+            eval_labels = []
+            eval_preds = []
+            with torch.no_grad():
+                torch.cuda.empty_cache()
+                for batch_dataset in D_test.batches(batch_size=batch_size):
+                    # prepare base
+                    base_inputs = batch_dataset.base.tokens
+                    b_s = base_inputs["input_ids"].shape[0]
+                    for k, v in base_inputs.items():
+                        if v is not None and isinstance(v, torch.Tensor):
+                            base_inputs[k] = v.to(gpt2.device)
+                    # prepare source
+                    source_inputs = batch_dataset.source.tokens
+                    for k, v in source_inputs.items():
+                        if v is not None and isinstance(v, torch.Tensor):
+                            source_inputs[k] = v.to(gpt2.device)
+                    # prepare label
+                    labels = batch_dataset.patched_answer_tokens[:, 0].to(gpt2.device)
+
+                    assert all(x == 18 for x in batch_dataset.base.lengths)
+                    assert all(x == 18 for x in batch_dataset.source.lengths)
+                    _, counterfactual_outputs = intervenable(
+                        {"input_ids": base_inputs["input_ids"]},
+                        [{"input_ids": source_inputs["input_ids"]}, {"input_ids": base_inputs["input_ids"]}],
+                        {
+                            "sources->base": (
+                                [[[aligning_pos]] * b_s]+[[[aligning_pos]] * b_s]*num_restores,
+                                [[[aligning_pos]] * b_s]+[[[aligning_pos]] * b_s]*num_restores,
+                            )
+                        },
+                    )
+                    eval_labels += [labels]
+                    eval_preds += [counterfactual_outputs.logits]
+            eval_metrics = compute_metrics(eval_preds, eval_labels)
+
+            data.append(
+                {
+                    "pos": aligning_pos,
+                    "layer": aligning_layer,
+                    "acc": eval_metrics["accuracy"],
+                    "kl_div": eval_metrics["kl_div"],
+                    "stream": stream,
+                }
+            )
+
+    return data
\ No newline at end of file