From 9b7a91bee9d28f01327fc902dbc62d9ae8296daf Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 18:06:53 -0700 Subject: [PATCH 01/10] [Minor] Update dependency --- pyvene/models/intervenable_base.py | 21 +++++++++++++-------- requirements.txt | 3 +-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index fdf967f4..32241335 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1,5 +1,4 @@ import json, logging, torch, types -import nnsight import numpy as np from collections import OrderedDict from typing import List, Optional, Tuple, Union, Dict, Any @@ -27,6 +26,12 @@ from transformers.utils import ModelOutput from tqdm import tqdm, trange +try: + import nnsight +except: + print("nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.") + + @dataclass class IntervenableModelOutput(ModelOutput): """ @@ -226,7 +231,7 @@ def __init__(self, config, model, backend, **kwargs): # cached swapped activations (hot) self.hot_activations = {} - self.aux_loss = [] + self.full_intervention_outputs = [] # temp fields should not be accessed outside self._batched_setter_activation_select = {} @@ -1558,16 +1563,16 @@ def hook_callback(model, args, kwargs, output=None): else: if not isinstance(self.interventions[key][0], types.FunctionType): if intervention.is_source_constant: - intervened_representation = do_intervention( + raw_intervened_representation = do_intervention( selected_output, None, intervention, subspaces[key_i] if subspaces is not None else None, ) - if isinstance(intervened_representation, InterventionOutput): - if intervened_representation.loss is not None: - self.aux_loss.append(intervened_representation.loss) - intervened_representation = intervened_representation.output + if isinstance(raw_intervened_representation, InterventionOutput): + # memorize for other training objectives + self.full_intervention_outputs.append(raw_intervened_representation) + intervened_representation = raw_intervened_representation.output else: intervened_representation = do_intervention( selected_output, @@ -1866,7 +1871,7 @@ def forward( if sources is not None and not isinstance(sources, list): sources = [sources] - self.aux_loss.clear() + self.full_intervention_outputs.clear() self._cleanup_states() diff --git a/requirements.txt b/requirements.txt index f1518c3b..32a2f9a8 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch>=2.0.0 -transformers==4.40.2 +transformers==4.44.0 datasets>=2.16.1 protobuf>=3.20.0 matplotlib>=3.7.4 @@ -10,4 +10,3 @@ numpy>=1.23.5 fsspec>=2023.6.0 accelerate>=0.29.1 sentencepiece>=0.1.96 -nnsight>=0.1.0 From 1e1b09b190e4587784e7e0ffa158734980a32b1a Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 18:08:55 -0700 Subject: [PATCH 02/10] bump up version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a057fcd7..f3d2d3a9 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( name="pyvene", - version="0.1.4", + version="0.1.5", description="Use Activation Intervention to Interpret Causal Mechanism of Model", long_description=long_description, long_description_content_type='text/markdown', From 86a41cafa30d0c0cb924302ce04450594bdea5e5 Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 18:47:27 -0700 Subject: [PATCH 03/10] update --- pyvene/models/intervenable_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 32241335..60a476f7 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1573,6 +1573,8 @@ def hook_callback(model, args, kwargs, output=None): # memorize for other training objectives self.full_intervention_outputs.append(raw_intervened_representation) intervened_representation = raw_intervened_representation.output + else: + intervened_representation = raw_intervened_representation else: intervened_representation = do_intervention( selected_output, From 969ae2778a7f05acfb789bd17ac2f78a5ede559a Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 19:04:32 -0700 Subject: [PATCH 04/10] update --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 32a2f9a8..7f491f48 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch>=2.0.0 -transformers==4.44.0 +transformers>=4.40.0 datasets>=2.16.1 protobuf>=3.20.0 matplotlib>=3.7.4 From 95e5122db3aac34bb103b5f98bd2b21b4801c5ad Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 19:30:29 -0700 Subject: [PATCH 05/10] update --- pyvene/models/intervenable_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 60a476f7..db84b57d 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1570,7 +1570,6 @@ def hook_callback(model, args, kwargs, output=None): subspaces[key_i] if subspaces is not None else None, ) if isinstance(raw_intervened_representation, InterventionOutput): - # memorize for other training objectives self.full_intervention_outputs.append(raw_intervened_representation) intervened_representation = raw_intervened_representation.output else: From 3c407cfba9f5e26caedf320bf6421275b1b7e2ac Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 20:51:07 -0700 Subject: [PATCH 06/10] update --- pyvene/models/intervenable_base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index db84b57d..32aaa525 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1563,17 +1563,17 @@ def hook_callback(model, args, kwargs, output=None): else: if not isinstance(self.interventions[key][0], types.FunctionType): if intervention.is_source_constant: - raw_intervened_representation = do_intervention( + intervened_representation = do_intervention( selected_output, None, intervention, subspaces[key_i] if subspaces is not None else None, ) - if isinstance(raw_intervened_representation, InterventionOutput): - self.full_intervention_outputs.append(raw_intervened_representation) - intervened_representation = raw_intervened_representation.output - else: - intervened_representation = raw_intervened_representation + # if isinstance(raw_intervened_representation, InterventionOutput): + # self.full_intervention_outputs.append(raw_intervened_representation) + # intervened_representation = raw_intervened_representation.output + # else: + # intervened_representation = raw_intervened_representation else: intervened_representation = do_intervention( selected_output, From eb4efe062b6f4854271a367dd665ab331204dac4 Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 20:57:41 -0700 Subject: [PATCH 07/10] update --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7f491f48..32a2f9a8 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch>=2.0.0 -transformers>=4.40.0 +transformers==4.44.0 datasets>=2.16.1 protobuf>=3.20.0 matplotlib>=3.7.4 From 3cefc8b89081d816b917895a1c8041e729d65fa8 Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 21:16:13 -0700 Subject: [PATCH 08/10] update --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 32a2f9a8..5292d6d8 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch>=2.0.0 -transformers==4.44.0 +transformers==4.44.2 datasets>=2.16.1 protobuf>=3.20.0 matplotlib>=3.7.4 From 46c827ec68f2626ed651682f14f15a6068202f2c Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 22:19:10 -0700 Subject: [PATCH 09/10] update --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5292d6d8..b52bae79 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch>=2.0.0 -transformers==4.44.2 +transformers==4.40.2 datasets>=2.16.1 protobuf>=3.20.0 matplotlib>=3.7.4 From 4b14b6e9c546f6c2a9cc59c7af0c8b8d8129b970 Mon Sep 17 00:00:00 2001 From: frankaging Date: Fri, 23 Aug 2024 22:38:11 -0700 Subject: [PATCH 10/10] update --- pyvene/models/intervenable_base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 32aaa525..db84b57d 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1563,17 +1563,17 @@ def hook_callback(model, args, kwargs, output=None): else: if not isinstance(self.interventions[key][0], types.FunctionType): if intervention.is_source_constant: - intervened_representation = do_intervention( + raw_intervened_representation = do_intervention( selected_output, None, intervention, subspaces[key_i] if subspaces is not None else None, ) - # if isinstance(raw_intervened_representation, InterventionOutput): - # self.full_intervention_outputs.append(raw_intervened_representation) - # intervened_representation = raw_intervened_representation.output - # else: - # intervened_representation = raw_intervened_representation + if isinstance(raw_intervened_representation, InterventionOutput): + self.full_intervention_outputs.append(raw_intervened_representation) + intervened_representation = raw_intervened_representation.output + else: + intervened_representation = raw_intervened_representation else: intervened_representation = do_intervention( selected_output,