From becd113c1c6e02e8b77a142d33f9e7910fd840a3 Mon Sep 17 00:00:00 2001 From: frankaging Date: Tue, 26 Mar 2024 00:27:16 -0700 Subject: [PATCH 1/3] [Minor] Accepting field for loss calculation --- pyvene/models/intervenable_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 1033cbaa..7c4dcd5c 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1317,6 +1317,7 @@ def forward( unit_locations: Optional[Dict] = None, source_representations: Optional[Dict] = None, subspaces: Optional[List] = None, + labels: Optional[torch.LongTensor] = None, output_original_output: Optional[bool] = False, return_dict: Optional[bool] = None, ): @@ -1438,7 +1439,7 @@ def forward( ) # run intervened forward - counterfactual_outputs = self.model(**base) + counterfactual_outputs = self.model(**base, labels=labels) set_handlers_to_remove.remove() self._output_validation() From 1d924c1693cd2b7b0572e5e066f55c5806a83139 Mon Sep 17 00:00:00 2001 From: frankaging Date: Tue, 26 Mar 2024 00:37:16 -0700 Subject: [PATCH 2/3] minor adjust --- pyvene/models/intervenable_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 7c4dcd5c..f1426648 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1439,7 +1439,10 @@ def forward( ) # run intervened forward - counterfactual_outputs = self.model(**base, labels=labels) + if labels is not None: + counterfactual_outputs = self.model(**base, labels=labels) + else: + counterfactual_outputs = self.model(**base) set_handlers_to_remove.remove() self._output_validation() From 9c5a2ffda9e4ac9feab1b2d4311ce9f08253c535 Mon Sep 17 00:00:00 2001 From: frankaging Date: Tue, 26 Mar 2024 00:45:29 -0700 Subject: [PATCH 3/3] bump up version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 07692470..506b45d3 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( name="pyvene", - version="0.0.8dev", + version="0.0.8", description="Use Activation Intervention to Interpret Causal Mechanism of Model", long_description=long_description, long_description_content_type='text/markdown',