Commit

Make generation tests generic
TimoImhof committed Sep 16, 2024
1 parent 7ece098 commit 93bee1a
Showing 6 changed files with 24 additions and 17 deletions.
8 changes: 4 additions & 4 deletions tests/composition/test_parallel.py
@@ -131,10 +131,10 @@ def test_parallel_generate(self):
         seq_output_length = 32
 
         # Finally, also check if generation works properly
-        if self.is_speech_model:
-            input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"]
-        else:
-            input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"]
+        input_ids = self.extract_input_ids(
+            self.get_input_samples(self.generate_input_samples_shape, config=model1.config)
+        )
+
         input_ids = input_ids.to(torch_device)
         generated = model1.generate(input_ids, max_length=seq_output_length)
         self.assertLessEqual(generated.shape, (2, seq_output_length))
8 changes: 4 additions & 4 deletions tests/methods/test_compacter.py
@@ -71,10 +71,10 @@ def test_compacter_generate(self):
         seq_output_length = 32
 
         # Finally, also check if generation works properly
-        if self.is_speech_model:
-            input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"]
-        else:
-            input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"]
+        input_ids = self.extract_input_ids(
+            self.get_input_samples(self.generate_input_samples_shape, config=model1.config)
+        )
+
         input_ids = input_ids.to(torch_device)
         generated = model1.generate(input_ids, max_length=seq_output_length)
         self.assertLessEqual(generated.shape, (1, seq_output_length))
8 changes: 4 additions & 4 deletions tests/methods/test_prefix_tuning.py
@@ -94,10 +94,10 @@ def test_prefix_tuning_generate(self):
         seq_output_length = 32
 
         # Finally, also check if generation works properly
-        if self.is_speech_model:
-            input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"]
-        else:
-            input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"]
+        input_ids = self.extract_input_ids(
+            self.get_input_samples(self.generate_input_samples_shape, config=model1.config)
+        )
+
         input_ids = input_ids.to(torch_device)
         generated = model1.generate(input_ids, max_length=seq_output_length)
         self.assertLessEqual(generated.shape, (1, seq_output_length))
9 changes: 9 additions & 0 deletions tests/test_adapter.py
@@ -35,8 +35,10 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
 class AdapterTestBase:
     # If not overridden by subclass, AutoModel should be used.
     model_class = AutoAdapterModel
+    tokenizer_name = "tests/fixtures/SiBERT"
     # Default shape of inputs to use
     default_input_samples_shape = (3, 64)
+    generate_input_samples_shape = (1, 4)
     leave_out_layers = [0, 1]
     do_run_train_tests = True
     # default arguments for test_adapter_heads
@@ -98,6 +100,9 @@ def assert_adapter_unavailable(self, model, adapter_name):
         self.assertFalse(adapter_name in model.adapters_config)
         self.assertEqual(len(model.get_adapter(adapter_name)), 0)
 
+    def extract_input_ids(self, inputs):
+        return inputs["input_ids"]
+
 
 class VisionAdapterTestBase(AdapterTestBase):
     default_input_samples_shape = (3, 3, 224, 224)
@@ -146,10 +151,14 @@ class SpeechAdapterTestBase(AdapterTestBase):
     """Base class for speech adapter tests."""
 
     default_input_samples_shape = (3, 80, 3000)  # (batch_size, n_mels, enc_seq_len)
+    generate_input_samples_shape = (1, 80, 3000)
     is_speech_model = True  # Flag for tests to determine if the model is a speech model due to input format difference
     time_window = 3000  # Time window for audio samples
     seq_length = 80
 
+    def extract_input_ids(self, inputs):
+        return inputs["input_features"]
+
     def add_head(self, model, name, head_type="seq2seq_lm", **kwargs):
         """Adds a head to the model."""
         if head_type == "audio_classification":
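Taken together, these changes introduce a small template-method pattern: AdapterTestBase supplies a default generate_input_samples_shape and an extract_input_ids hook, and SpeechAdapterTestBase overrides both, so the generation tests above no longer branch on is_speech_model. Below is a minimal, runnable sketch of the pattern; the shapes and extract_input_ids bodies match the diff, but the get_input_samples bodies and the run_generate_check helper are simplified stand-ins, not the repository's actual implementations.

import torch


class AdapterTestBase:
    # Default generation input shape for text models: (batch_size, seq_len).
    generate_input_samples_shape = (1, 4)

    def get_input_samples(self, shape):
        # Stand-in: text models consume random token ids.
        return {"input_ids": torch.randint(0, 1000, shape)}

    def extract_input_ids(self, inputs):
        return inputs["input_ids"]


class SpeechAdapterTestBase(AdapterTestBase):
    # Speech models take log-mel features: (batch_size, n_mels, enc_seq_len).
    generate_input_samples_shape = (1, 80, 3000)

    def get_input_samples(self, shape):
        # Stand-in: speech models consume feature tensors instead of token ids.
        return {"input_features": torch.rand(shape)}

    def extract_input_ids(self, inputs):
        return inputs["input_features"]


def run_generate_check(test_base, model):
    # Hypothetical helper: works for any modality, since the subclass
    # decides both the input shape and which tensor feeds generate().
    inputs = test_base.get_input_samples(test_base.generate_input_samples_shape)
    input_ids = test_base.extract_input_ids(inputs)
    return model.generate(input_ids, max_length=32)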
2 changes: 1 addition & 1 deletion tests/test_adapter_embeddings.py
@@ -182,6 +182,6 @@ def _instantiate_tokenizer(self, model):
             tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
             input_data = self.get_input_samples(config=self.config())
         else:
-            tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT")
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
             input_data = self.get_input_samples((1, 128), vocab_size=tokenizer.vocab_size, config=model.config)
         return tokenizer, input_data
6 changes: 2 additions & 4 deletions tests/test_adapter_heads.py
@@ -175,10 +175,8 @@ def test_seq2seq_lm_head(self):
 
         # Finally, also check if generation works properly
         input_shape = self._get_input_shape()
-        if self.is_speech_model:
-            input_ids = self.get_input_samples(input_shape, config=model1.config)["input_features"]
-        else:
-            input_ids = self.get_input_samples(input_shape, config=model1.config)["input_ids"]
+        input_ids = self.extract_input_ids(self.get_input_samples(input_shape, config=model1.config))
+
         input_ids = input_ids.to(torch_device)
         # Use a different length for the seq2seq output
         seq_output_length = self.seq_length + 30
