Added distil_whisper. (#2074)
Summary:
Should work! Added distil-whisper to torchbench. Local run time is within normal CI requirements.
![image](https://github.com/pytorch/benchmark/assets/24942306/0f7e7cae-ae90-44e6-bbff-4eede4a4f730)

Pull Request resolved: #2074

Reviewed By: aaronenyeshi

Differential Revision: D52053930

Pulled By: xuzhao9

fbshipit-source-id: f0d20a821c5de916ca174b0033d9d79b5d6cafa0
MaanavD authored and facebook-github-bot committed Dec 12, 2023
1 parent 075f804 commit d61b74a
Showing 5 changed files with 56 additions and 0 deletions.
30 changes: 30 additions & 0 deletions torchbenchmark/models/hf_distil_whisper/__init__.py
@@ -0,0 +1,30 @@
from torchbenchmark.tasks import SPEECH
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel
import torch

class Model(HuggingFaceModel):
    task = SPEECH.RECOGNITION
    DEFAULT_TRAIN_BSIZE = 8
    DEFAULT_EVAL_BSIZE = 1

    def __init__(self, test, device, batch_size=None, extra_args=[]):
        if test == "train":
            raise NotImplementedError("Training is not implemented.")
        super().__init__(name="hf_distil_whisper", test=test, device=device, batch_size=batch_size, extra_args=extra_args)
        # Whisper-style encoders take an 80-bin log-mel spectrogram over 3000 frames (30 s of audio).
        self.feature_size = 80
        self.sequence_length = 3000
        self.input_features = torch.randn(size=(self.batch_size, self.feature_size, self.sequence_length), device=self.device)
        # Both keys alias the same spectrogram tensor; eval() feeds it as the model's first positional argument.
        self.example_inputs = {"input_features": self.input_features.to(self.device), "input_ids": self.input_features.to(self.device)}
        self.model.to(self.device)

    def train(self):
        raise NotImplementedError("Training is not implemented.")

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            self.model(self.example_inputs["input_ids"])

    def enable_fp16(self):
        self.model.half()
        self.example_inputs = {"input_features": self.input_features.half().to(self.device), "input_ids": self.input_features.half().to(self.device)}
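For orientation, a minimal sketch of driving this class directly; torchbench's own runner normally constructs the model, and the constructor signature here is taken from the diff above (a CUDA device is assumed):

```python
# Hypothetical direct usage; the benchmark harness normally does this itself.
from torchbenchmark.models.hf_distil_whisper import Model

m = Model(test="eval", device="cuda")  # batch_size=None falls back to DEFAULT_EVAL_BSIZE (1)
m.eval()  # one forward pass over the random spectrogram under torch.no_grad()
```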
14 changes: 14 additions & 0 deletions torchbenchmark/models/hf_distil_whisper/install.py
@@ -0,0 +1,14 @@

import subprocess
import sys
import os
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model

def pip_install_requirements():
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])

if __name__ == '__main__':
    pip_install_requirements()
    patch_transformers()
    # The model name is this script's directory name, i.e. 'hf_distil_whisper'.
    model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
    cache_model(model_name)
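Note that install.py resolves requirements.txt relative to the working directory, so a caller would run it from the model's own folder; a minimal sketch of that invocation (path taken from the diff above):

```python
# Sketch: run the install step with the model directory as the working
# directory so the relative 'requirements.txt' path resolves.
import subprocess
import sys

subprocess.check_call(
    [sys.executable, "install.py"],
    cwd="torchbenchmark/models/hf_distil_whisper",
)
```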
8 changes: 8 additions & 0 deletions torchbenchmark/models/hf_distil_whisper/metadata.yaml
@@ -0,0 +1,8 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 16
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
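A harness would read these per-device overrides before constructing the model; a minimal sketch of doing so, assuming PyYAML is available and the repo-relative path above:

```python
# Sketch only: read the per-device eval batch size the way a harness might.
import yaml

with open("torchbenchmark/models/hf_distil_whisper/metadata.yaml") as f:
    meta = yaml.safe_load(f)

print(meta["devices"]["NVIDIA A100-SXM4-40GB"]["eval_batch_size"])  # -> 16
```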
2 changes: 2 additions & 0 deletions torchbenchmark/models/hf_distil_whisper/requirements.txt
@@ -0,0 +1,2 @@
sentencepiece
datasets
2 changes: 2 additions & 0 deletions torchbenchmark/util/framework/huggingface/model_factory.py
@@ -28,6 +28,7 @@
    # see https://huggingface.co/bert-large-cased
    'hf_Bert_large': (512, 512, 'BertConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16)', 'AutoModelForMaskedLM'),
    'hf_Whisper': (1024, 1024, 'WhisperConfig()', 'AutoModelForAudioClassification'),
    'hf_distil_whisper': (1024, 1024, 'AutoConfig.from_pretrained("distil-whisper/distil-medium.en")', 'AutoModelForAudioClassification'),
    # default num_hidden_layers=32 but that OOMs, feel free to change this config to something more real
    'llama_v2_7b_16h' : (128, 512, 'LlamaConfig(num_hidden_layers=16)', 'AutoModelForCausalLM'),
    'hf_MPT_7b_instruct': (512, 512, 'AutoConfig.from_pretrained("mosaicml/mpt-7b-instruct", trust_remote_code=True)', 'AutoModelForCausalLM'),
@@ -36,6 +37,7 @@
    'llama_v2_70b' : (512, 512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")', 'AutoModelForMaskedLM'),
    'phi_1_5' : (512, 512, 'AutoConfig.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)', 'AutoModelForCausalLM'),
    'hf_Yi' : (512, 512, 'AutoConfig.from_pretrained("01-ai/Yi-6B", trust_remote_code=True)', 'AutoModelForCausalLM'),

}

cpu_input_slice = {
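Each entry in this registry pairs train/eval sequence lengths with a config expression and an auto-model class name that the factory evaluates. A hedged sketch of what the new 'hf_distil_whisper' entry resolves to in plain transformers code (Hub access assumed; the model is built from the config, so weights are randomly initialized rather than pretrained):

```python
# Sketch: the registry's config string, evaluated as transformers calls.
from transformers import AutoConfig, AutoModelForAudioClassification

config = AutoConfig.from_pretrained("distil-whisper/distil-medium.en")
model = AutoModelForAudioClassification.from_config(config)  # random init, no weight download
```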
