diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py
index e8e8af1ec5..a450260fc6 100644
--- a/lm_eval/models/__init__.py
+++ b/lm_eval/models/__init__.py
@@ -5,6 +5,7 @@
 from . import megatronlm
 from . import textsynth
 from . import dummy
+from . import modalities
 
 MODEL_REGISTRY = {
     "hf": gpt2.HFLM,
@@ -17,6 +18,7 @@
     "megatronlm": megatronlm.MegatronLMClient,
     "textsynth": textsynth.TextSynthLM,
     "dummy": dummy.DummyLM,
+    "modalities": modalities.Modalities
 }
diff --git a/lm_eval/models/huggingface.py b/lm_eval/models/huggingface.py
index 2dca98a1a7..825d707d38 100644
--- a/lm_eval/models/huggingface.py
+++ b/lm_eval/models/huggingface.py
@@ -91,6 +91,7 @@ def __init__(
         load_in_4bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         gptq_use_triton: Optional[bool] = False,
+        **kwargs
     ):
         """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
         Args:
@@ -214,6 +215,7 @@ def __init__(
             load_in_8bit=load_in_8bit,
             load_in_4bit=load_in_4bit,
             **model_kwargs,
+            **kwargs
         )
         # note: peft_path can be different than pretrained model path
         if peft is not None:
@@ -251,6 +253,7 @@ def _create_auto_model(
         trust_remote_code: Optional[bool] = False,
         torch_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq_use_triton: Optional[bool] = False,
+        **kwargs
     ) -> transformers.AutoModel:
         """Returns a pre-trained pytorch model from a pre-trained model configuration."""
         if not quantized:
@@ -271,6 +274,7 @@ def _create_auto_model(
                 trust_remote_code=trust_remote_code,
                 torch_dtype=torch_dtype,
                 **model_kwargs,
+                **kwargs
             )
         else:
             from auto_gptq import AutoGPTQForCausalLM
diff --git a/lm_eval/models/modalities.py b/lm_eval/models/modalities.py
new file mode 100644
index 0000000000..d0246fded1
--- /dev/null
+++ b/lm_eval/models/modalities.py
@@ -0,0 +1,22 @@
+
+from typing import Union, List, Optional
+import torch
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, BatchEncoding
+from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapterConfig, HFModelAdapter
+from .huggingface import AutoCausalLM
+
+TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
+
+
+class Modalities(AutoCausalLM):
+    def __init__(self, *args, **kwargs):
+        AutoConfig.register("modalities", HFModelAdapterConfig)
+        AutoModelForCausalLM.register(HFModelAdapterConfig, HFModelAdapter)
+        # TODO load our own tokenizer
+        super().__init__(tokenizer="gpt2", *args, **kwargs)
+
+
+    def _model_call(
+        self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
+    ) -> TokenSequence:
+        return self.model(inputs)
diff --git a/tests/test_models.py b/tests/test_models.py
index c50332da6e..ae595319db 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -54,9 +54,26 @@
 ]
 
 
-# Test HuggingFace Models (GPT-2)
+def test_modalities():
+    # dismiss sequences that are too long for our test checkpoint
+    test_cases = LOGLIKELIHOOD_TEST_CASES[:5]
+    modalities = models.get_model("modalities").create_from_arg_string("pretrained='testdata/models/modalities/checkpoint'")
+    results = modalities.loglikelihood(test_cases)
+    for loglikelihood, is_max_loglikelihood in results:
+        assert type(loglikelihood) == float
+        assert type(is_max_loglikelihood) == bool
+
+    # test empty context
+    modalities.loglikelihood([("", "test")])
+    greedy_len = 20
+    gen = modalities.greedy_until(
+        [("The quick brown fox jumps over the lazy", {"until": [".", "\n"], "max_length": greedy_len})]
+    )[0]
+    assert type(gen) == str
+    assert len(gen.split()) == greedy_len
+
+
+# Test HuggingFace Models (GPT-2)
 def test_gpt2():
     gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
     (
@@ -79,7 +96,7 @@ def test_gpt2():
     gpt2.loglikelihood([("", "test")])
 
     (gen,) = gpt2.greedy_until(
-        [("The quick brown fox jumps over the lazy", [".", "\n"])]
+        [("The quick brown fox jumps over the lazy", {"until": [".", "\n"]})]
     )
 
     assert gen == ", lazy fox and they both fall to the ground"
diff --git a/tests/testdata/models/modalities/checkpoint/config.json b/tests/testdata/models/modalities/checkpoint/config.json
new file mode 100644
index 0000000000..515fb9f9fc
--- /dev/null
+++ b/tests/testdata/models/modalities/checkpoint/config.json
@@ -0,0 +1 @@
+{"config": {"sample_key": "input_ids", "prediction_key": "logits", "block_size": 128, "vocab_size": 50304, "n_layer": 1, "n_head": 1, "n_embd": 128, "ffn_hidden": 128, "dropout": 0.0, "bias": true, "attention": {"attention_type": "pytorch_flash_attention", "scaling_factor": 3}, "activation": "gelu", "epsilon": 1e-05, "weight_init": {"mean": 0.0, "std": 0.02}}, "model_type": "modalities_gpt2"}
\ No newline at end of file
diff --git a/tests/testdata/models/modalities/checkpoint/pytorch_model.bin b/tests/testdata/models/modalities/checkpoint/pytorch_model.bin
new file mode 100644
index 0000000000..37aeb8650a
Binary files /dev/null and b/tests/testdata/models/modalities/checkpoint/pytorch_model.bin differ