diff --git a/torchbenchmark/canary_models/hf_mixtral/__init__.py b/torchbenchmark/canary_models/hf_mixtral/__init__.py
new file mode 100644
index 0000000000..6d91f8f86a
--- /dev/null
+++ b/torchbenchmark/canary_models/hf_mixtral/__init__.py
@@ -0,0 +1,17 @@
+from torchbenchmark.tasks import NLP
+from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel
+
+class Model(HuggingFaceModel):
+    task = NLP.LANGUAGE_MODELING
+    # DEFAULT_TRAIN_BSIZE not specified since we're not implementing a train test
+    # DEFAULT_TRAIN_BSIZE = 1
+    DEFAULT_EVAL_BSIZE = 1
+
+    def __init__(self, test, device, batch_size=None, extra_args=[]):
+        super().__init__(name="hf_mixtral", test=test, device=device, batch_size=batch_size, extra_args=extra_args)
+
+    # # def train(self):
+    # #     return NotImplementedError("Not implemented")
+
+    # def eval(self):
+    #     super().eval()
\ No newline at end of file
diff --git a/torchbenchmark/canary_models/hf_mixtral/install.py b/torchbenchmark/canary_models/hf_mixtral/install.py
new file mode 100644
index 0000000000..64e5b1127e
--- /dev/null
+++ b/torchbenchmark/canary_models/hf_mixtral/install.py
@@ -0,0 +1,13 @@
+import subprocess
+import sys
+import os
+from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model
+
+def pip_install_requirements():
+    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
+
+if __name__ == '__main__':
+    pip_install_requirements()
+    patch_transformers()
+    model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
+    cache_model(model_name, trust_remote_code=True)
\ No newline at end of file
diff --git a/torchbenchmark/canary_models/hf_mixtral/metadata.yaml b/torchbenchmark/canary_models/hf_mixtral/metadata.yaml
new file mode 100644
index 0000000000..19877db021
--- /dev/null
+++ b/torchbenchmark/canary_models/hf_mixtral/metadata.yaml
@@ -0,0 +1,11 @@
+devices:
+  NVIDIA A100-SXM4-40GB:
+    eval_batch_size: 1
+eval_benchmark: false
+eval_deterministic: false
+eval_nograd: true
+train_benchmark: false
+train_deterministic: false
+not_implemented:
+  - device: NVIDIA A10G
+  # - device: cpu
\ No newline at end of file
diff --git a/torchbenchmark/canary_models/hf_mixtral/requirements.txt b/torchbenchmark/canary_models/hf_mixtral/requirements.txt
new file mode 100644
index 0000000000..5d54ada424
--- /dev/null
+++ b/torchbenchmark/canary_models/hf_mixtral/requirements.txt
@@ -0,0 +1,3 @@
+bitsandbytes
+transformers>=4.36.2
+numba
\ No newline at end of file
diff --git a/torchbenchmark/util/framework/huggingface/model_factory.py b/torchbenchmark/util/framework/huggingface/model_factory.py
index 4d23251999..c32c4c54dc 100644
--- a/torchbenchmark/util/framework/huggingface/model_factory.py
+++ b/torchbenchmark/util/framework/huggingface/model_factory.py
@@ -29,6 +29,7 @@
     'hf_Bert_large': (512, 512, 'BertConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16)', 'AutoModelForMaskedLM'),
     'hf_Whisper': (1024, 1024, 'WhisperConfig()', 'AutoModelForAudioClassification'),
     'hf_distil_whisper': (1024, 1024, 'AutoConfig.from_pretrained("distil-whisper/distil-medium.en")', 'AutoModelForAudioClassification'),
+    'hf_mixtral' : (512,512, 'AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")', 'AutoModelForCausalLM'),
     # default num_hidden_layers=32 but that OOMs, feel free to change this config to something more real
     'llama_v2_7b_16h' : (128,512, 'LlamaConfig(num_hidden_layers=16)', 'AutoModelForCausalLM'),
     'hf_MPT_7b_instruct': (512, 512, 'AutoConfig.from_pretrained("mosaicml/mpt-7b-instruct", trust_remote_code=True)', 'AutoModelForCausalLM'),
@@ -39,7 +40,6 @@
     # as per this page https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 trust_remote_code=True is not required
     'mistral_7b_instruct' : (128, 128, 'AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")', 'AutoModelForCausalLM'),
     'hf_Yi' : (512, 512, 'AutoConfig.from_pretrained("01-ai/Yi-6B", trust_remote_code=True)', 'AutoModelForCausalLM'),
-
 }
 
 cpu_input_slice = {