Skip to content

Commit

Permalink
Merge pull request #58 from stanfordnlp/zen/updatetestname
Browse files Browse the repository at this point in the history
[Minor] update test name
  • Loading branch information
frankaging authored Jan 17, 2024
2 parents 70e6cac + 80b12ee commit f46ab90
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ..utils import *


class SubspaceInterventionWithTransformerTestCase(unittest.TestCase):
class ComplexInterventionWithGPT2TestCase(unittest.TestCase):
@classmethod
def setUpClass(self):
print("=== Test Suite: VanillaInterventionWithTransformerTestCase ===")
Expand Down Expand Up @@ -167,20 +167,20 @@ def test_lowrank_rotate_subspace_partition_in_forward_positive(self):
def suite():
suite = unittest.TestSuite()
suite.addTest(
SubspaceInterventionWithTransformerTestCase("test_clean_run_positive")
ComplexInterventionWithGPT2TestCase("test_clean_run_positive")
)
suite.addTest(
SubspaceInterventionWithTransformerTestCase(
ComplexInterventionWithGPT2TestCase(
"test_vanilla_subspace_partition_in_forward_positive"
)
)
suite.addTest(
SubspaceInterventionWithTransformerTestCase(
ComplexInterventionWithGPT2TestCase(
"test_rotate_subspace_partition_in_forward_positive"
)
)
suite.addTest(
SubspaceInterventionWithTransformerTestCase(
ComplexInterventionWithGPT2TestCase(
"test_lowrank_rotate_subspace_partition_in_forward_positive"
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from ..utils import *


class VanillaInterventionWithTransformerTestCase(unittest.TestCase):
class InterventionWithGPT2TestCase(unittest.TestCase):
@classmethod
def setUpClass(self):
print("=== Test Suite: VanillaInterventionWithTransformerTestCase ===")
print("=== Test Suite: InterventionWithGPT2TestCase ===")
self.config, self.tokenizer, self.gpt2 = create_gpt2_lm(
config=GPT2Config(
n_embd=24,
Expand Down Expand Up @@ -375,44 +375,44 @@ def test_with_location_broadcast_vanilla_intervention_positive(self):

def suite():
suite = unittest.TestSuite()
suite.addTest(VanillaInterventionWithTransformerTestCase("test_clean_run_positive"))
suite.addTest(InterventionWithGPT2TestCase("test_clean_run_positive"))
suite.addTest(
VanillaInterventionWithTransformerTestCase(
InterventionWithGPT2TestCase(
"test_invalid_intervenable_unit_negative"
)
)
suite.addTest(
VanillaInterventionWithTransformerTestCase(
InterventionWithGPT2TestCase(
"test_with_single_position_vanilla_intervention_positive"
)
)
suite.addTest(
VanillaInterventionWithTransformerTestCase(
InterventionWithGPT2TestCase(
"test_with_multiple_position_vanilla_intervention_positive"
)
)
suite.addTest(
VanillaInterventionWithTransformerTestCase(
InterventionWithGPT2TestCase(
"test_with_complex_position_vanilla_intervention_positive"
)
)
suite.addTest(
VanillaInterventionWithTransformerTestCase(
InterventionWithGPT2TestCase(
"test_with_single_head_position_vanilla_intervention_positive"
)
)
suite.addTest(
VanillaInterventionWithTransformerTestCase(
InterventionWithGPT2TestCase(
"test_with_multiple_heads_positions_vanilla_intervention_positive"
)
)
suite.addTest(
VanillaInterventionWithTransformerTestCase(
InterventionWithGPT2TestCase(
"test_with_use_fast_vanilla_intervention_positive"
)
)
suite.addTest(
VanillaInterventionWithTransformerTestCase(
InterventionWithGPT2TestCase(
"test_with_location_broadcast_vanilla_intervention_positive"
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from ..utils import *


class SubspaceInterventionWithMLPTestCase(unittest.TestCase):
class InterventionWithMLPTestCase(unittest.TestCase):
@classmethod
def setUpClass(self):
print("=== Test Suite: SubspaceInterventionWithMLPTestCase ===")
print("=== Test Suite: InterventionWithMLPTestCase ===")
self.config, self.tokenizer, self.mlp = create_mlp_classifier(
MLPConfig(
h_dim=3, n_layer=1, pdrop=0.0, num_classes=5,
Expand Down Expand Up @@ -302,17 +302,17 @@ def test_no_intervention_link_negative(self):

def suite():
suite = unittest.TestSuite()
suite.addTest(SubspaceInterventionWithMLPTestCase("test_clean_run_positive"))
suite.addTest(SubspaceInterventionWithMLPTestCase("test_with_subspace_positive"))
suite.addTest(SubspaceInterventionWithMLPTestCase("test_with_subspace_negative"))
suite.addTest(InterventionWithMLPTestCase("test_clean_run_positive"))
suite.addTest(InterventionWithMLPTestCase("test_with_subspace_positive"))
suite.addTest(InterventionWithMLPTestCase("test_with_subspace_negative"))
suite.addTest(
SubspaceInterventionWithMLPTestCase("test_intervention_link_positive")
InterventionWithMLPTestCase("test_intervention_link_positive")
)
suite.addTest(
SubspaceInterventionWithMLPTestCase("test_no_intervention_link_positive")
InterventionWithMLPTestCase("test_no_intervention_link_positive")
)
suite.addTest(
SubspaceInterventionWithMLPTestCase("test_no_intervention_link_negative")
InterventionWithMLPTestCase("test_no_intervention_link_negative")
)
return suite

Expand Down

0 comments on commit f46ab90

Please sign in to comment.