diff --git a/models/demos/whisper/README.md b/models/demos/whisper/README.md
new file mode 100644
index 00000000000..b0de2134dbd
--- /dev/null
+++ b/models/demos/whisper/README.md
@@ -0,0 +1,75 @@
+# Functional Whisper Model Demos For Audio Classification and Text Generation
+
+## Introduction
+
+Whisper is a pre-trained model for Automatic Speech Recognition (ASR) and Speech Translation. The models are trained on either English-only or multilingual data: the English-only models were trained on speech recognition, while the multilingual models were trained on both speech recognition and speech translation.
+
+The demos showcase the Functional Whisper model for audio classification and text generation. The `sanchit-gandhi/whisper-medium-fleurs-lang-id` and `openai/whisper-tiny.en` checkpoints from Hugging Face are used for the respective tasks.
+
+### Details
+
+The entry point to the Functional Whisper model is the `whisper` function located in `ttnn_optimized_functional_whisper.py`.
+
+## Inputs
+
+By default, inputs are provided from the `dataset/audio_classification` and `dataset/conditional_generation` folders. To modify the inputs or specify a different path, adjust the `--input-path` parameter in the command accordingly. It is recommended to avoid modifying the `input_data.json` file directly.
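+
+For example, the text generation demo can be pointed at a different folder of audio files by overriding `--input-path`. The command below is only a sketch: the folder path is a placeholder, and the folder must contain at least as many `.wav` files as the batch size (8 here), in the same format as the bundled dataset folders:
+```
+pytest --disable-warnings --input-path="path/to/your/audio_folder" models/demos/whisper/demo/demo.py::test_demo_for_conditional_generation[8-32-models.demos.whisper.tt.ttnn_optimized_functional_whisper]
+```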
--input-path="models/demos/whisper/demo/dataset/conditional_generation" models/demos/whisper/demo/demo.py::test_demo_for_conditional_generation[8-32-models.demos.whisper.tt.ttnn_functional_whisper] + ``` + +- Our second demo is designed to run with `hf-internal-testing/librispeech_asr_dummy` dataset for text generation. + +- To run the second demo using ttnn optimized functional whisper with dataset inputs for 1 iteration(s), each configured with a batch size of 8 and decoding up to 32 tokens, use the following command : + ``` + pytest --disable-warnings models/demos/whisper/demo/demo.py::test_demo_for_conditional_generation_dataset[8-1-32-models.demos.whisper.tt.ttnn_optimized_functional_whisper] + ``` +- To run the second demo using ttnn functional whisper with dataset inputs for 1 iteration(s), each configured with a batch size of 8 and decoding up to 32 tokens, use the following command: + ``` + pytest --disable-warnings models/demos/whisper/demo/demo.py::test_demo_for_conditional_generation_dataset[8-1-32-models.demos.whisper.tt.ttnn_functional_whisper] + ``` + +## Results + +The demos presents a comprehensive view of the Whisper model's robustness in audio classification and text generation tasks. + +Audio classification predicts the languange of the provided audio sample and the demo using dataset inputs provides the accuracy of the model. +For example, accuracy of 0.75 is observed with `batch_size=8` and `n_iterations=3` + +In Text generation, the model predicts transcriptions in the same language as the audio (English). diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/10116516891483200485.wav b/models/demos/whisper/demo/dataset/audio_classification/10116516891483200485.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/10116516891483200485.wav rename to models/demos/whisper/demo/dataset/audio_classification/10116516891483200485.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/140291826269534354.wav b/models/demos/whisper/demo/dataset/audio_classification/140291826269534354.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/140291826269534354.wav rename to models/demos/whisper/demo/dataset/audio_classification/140291826269534354.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/1689242038473278354.wav b/models/demos/whisper/demo/dataset/audio_classification/1689242038473278354.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/1689242038473278354.wav rename to models/demos/whisper/demo/dataset/audio_classification/1689242038473278354.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/17340315164505628698.wav b/models/demos/whisper/demo/dataset/audio_classification/17340315164505628698.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/17340315164505628698.wav rename to models/demos/whisper/demo/dataset/audio_classification/17340315164505628698.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/17659141715436566244.wav b/models/demos/whisper/demo/dataset/audio_classification/17659141715436566244.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/17659141715436566244.wav rename to 
models/demos/whisper/demo/dataset/audio_classification/17659141715436566244.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/17928171511082320095.wav b/models/demos/whisper/demo/dataset/audio_classification/17928171511082320095.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/17928171511082320095.wav rename to models/demos/whisper/demo/dataset/audio_classification/17928171511082320095.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/2086639904747050008.wav b/models/demos/whisper/demo/dataset/audio_classification/2086639904747050008.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/2086639904747050008.wav rename to models/demos/whisper/demo/dataset/audio_classification/2086639904747050008.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/622196158886216764.wav b/models/demos/whisper/demo/dataset/audio_classification/622196158886216764.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/622196158886216764.wav rename to models/demos/whisper/demo/dataset/audio_classification/622196158886216764.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/7043619860143829064.wav b/models/demos/whisper/demo/dataset/audio_classification/7043619860143829064.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/7043619860143829064.wav rename to models/demos/whisper/demo/dataset/audio_classification/7043619860143829064.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/9522084197299278725.wav b/models/demos/whisper/demo/dataset/audio_classification/9522084197299278725.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/9522084197299278725.wav rename to models/demos/whisper/demo/dataset/audio_classification/9522084197299278725.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/11150113890463037787.wav b/models/demos/whisper/demo/dataset/conditional_generation/11150113890463037787.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/11150113890463037787.wav rename to models/demos/whisper/demo/dataset/conditional_generation/11150113890463037787.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/1298409023920250606.wav b/models/demos/whisper/demo/dataset/conditional_generation/1298409023920250606.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/1298409023920250606.wav rename to models/demos/whisper/demo/dataset/conditional_generation/1298409023920250606.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17566024285835266239.wav b/models/demos/whisper/demo/dataset/conditional_generation/17566024285835266239.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17566024285835266239.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17566024285835266239.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17646385371758249908.wav 
b/models/demos/whisper/demo/dataset/conditional_generation/17646385371758249908.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17646385371758249908.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17646385371758249908.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17659141715436566244.wav b/models/demos/whisper/demo/dataset/conditional_generation/17659141715436566244.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17659141715436566244.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17659141715436566244.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17928171511082320095.wav b/models/demos/whisper/demo/dataset/conditional_generation/17928171511082320095.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17928171511082320095.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17928171511082320095.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17938133003986293739.wav b/models/demos/whisper/demo/dataset/conditional_generation/17938133003986293739.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17938133003986293739.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17938133003986293739.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/2842775607363710885.wav b/models/demos/whisper/demo/dataset/conditional_generation/2842775607363710885.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/2842775607363710885.wav rename to models/demos/whisper/demo/dataset/conditional_generation/2842775607363710885.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/6757317816154782558.wav b/models/demos/whisper/demo/dataset/conditional_generation/6757317816154782558.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/6757317816154782558.wav rename to models/demos/whisper/demo/dataset/conditional_generation/6757317816154782558.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/6969469525741631060.wav b/models/demos/whisper/demo/dataset/conditional_generation/6969469525741631060.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/6969469525741631060.wav rename to models/demos/whisper/demo/dataset/conditional_generation/6969469525741631060.wav diff --git a/models/experimental/functional_whisper/demo/demo.py b/models/demos/whisper/demo/demo.py similarity index 64% rename from models/experimental/functional_whisper/demo/demo.py rename to models/demos/whisper/demo/demo.py index 79d364e070d..6b2ed721cee 100644 --- a/models/experimental/functional_whisper/demo/demo.py +++ b/models/demos/whisper/demo/demo.py @@ -18,9 +18,10 @@ from models.utility_functions import ( disable_compilation_reports, disable_persistent_kernel_cache, + profiler, ) -from models.experimental.functional_whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper -from models.generation_utils import get_logits_processor +from models.demos.whisper.tt import 
ttnn_functional_whisper, ttnn_optimized_functional_whisper +from models.generation_utils import get_logits_processor, pad_input_32 from ttnn.model_preprocessing import preprocess_model_parameters import torch @@ -37,20 +38,6 @@ def load_input_paths(folder_path): return files -def pad_input_32(tensor, value): - len = tensor.shape[1] - - if len % 32 == 0: - return tensor - - padded_len = ((len // 32) + 1) * 32 - - pad_tensor = (value * torch.ones(tensor.shape[0], padded_len - len)).to(torch.long) - tensor = torch.cat([tensor, pad_tensor], dim=1) - - return tensor - - def run_generate( config, input_embeds, @@ -59,54 +46,59 @@ def run_generate( decoder_hidden_states, decoder_attention_mask, parameters, - processor, ttnn_linear_weight, device, + decoder_input_ids, generation_config, + batch_size, + max_tokens, ): - input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - - logits_processor = get_logits_processor(input_ids, config) - - input_ids = pad_input_32(input_ids, config.pad_token_id).to(torch.long) - - decoder_start_values = generation_config.pad_token_id * torch.ones(1, 32).to(torch.long) + logits_processor = get_logits_processor(decoder_input_ids, config) + decoder_start_values = generation_config.pad_token_id * torch.ones(batch_size, input_features.shape[1]).to( + torch.long + ) + eos_reached = torch.zeros(batch_size, dtype=torch.bool) - for i in range(32): - output = ttnn_model.whisper( - config, - input_embeds, - decoder_hidden_states, + profiler.start(f"inference_time") + for i in range(max_tokens): + ttnn_output = ttnn_model.whisper_for_conditional_generation( + config=config, + input_embeds=input_embeds, + decoder_hidden_states=decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=parameters, + device=device, + ttnn_linear_weight=ttnn_linear_weight, ) - output = output @ ttnn_linear_weight - - output = ttnn.from_device(output) - - logits_to_torch = ttnn.to_torch(output) - + ttnn_output = ttnn.from_device(ttnn_output) + logits_to_torch = ttnn.to_torch(ttnn_output) next_token_logits = logits_to_torch[:, i, :] - next_tokens_scores = logits_processor(input_features, next_token_logits) - next_tokens = torch.argmax(next_tokens_scores, dim=-1) + next_tokens = torch.argmax(next_tokens_scores, dim=-1).unsqueeze(0) - if (i + 1) % 32 == 0: - input_ids = torch.cat([input_ids, decoder_start_values], dim=1) + # Check if EOS token is generated for any sample in the batch and + # Setting subsequent next_tokens to config.pad_token_id if EOS token is reached. 
+ eos_generated_flags = next_tokens == config.eos_token_id + eos_reached = eos_reached | eos_generated_flags.squeeze(0) + next_tokens[:, eos_reached] = config.pad_token_id - input_ids[:, i + 1] = next_tokens[:, None] + if (i + 1) % 32 == 0: + decoder_input_ids = torch.cat([decoder_input_ids, decoder_start_values], dim=1) + decoder_input_ids[:, i + 1] = next_tokens[:, None] decoder_hidden_states, decoder_attention_mask = ttnn_model.preprocess_decoder_inputs( - config=config, input_ids=input_ids, attention_mask=None, parameters=parameters.decoder, device=device + config=config, + input_ids=decoder_input_ids, + attention_mask=None, + parameters=parameters.decoder, + device=device, ) - if next_tokens == config.eos_token_id: + if torch.all(next_tokens == config.eos_token_id): break - logger.info(processor.batch_decode(input_ids, skip_special_tokens=True)[0]) - - ttnn_transcription = processor.batch_decode(input_ids, skip_special_tokens=True)[0] - return ttnn_transcription + profiler.end(f"inference_time") + return decoder_input_ids def run_demo_functional_whisper_for_audio_classification_inference(input_path, ttnn_model, device, num_inputs): @@ -164,16 +156,13 @@ def run_demo_functional_whisper_for_audio_classification_inference(input_path, t logger.info(predicted_label) -def run_demo_functional_whisper_for_conditional_generation_inference(input_path, ttnn_model, device, num_inputs): - torch.manual_seed(0) - +def run_demo_functional_whisper_for_conditional_generation_inference( + input_path, ttnn_model, device, reset_seeds, batch_size=1, max_tokens=32 +): model = WhisperModel.from_pretrained("openai/whisper-tiny.en").to(torch.bfloat16).eval() - config = WhisperConfig.from_pretrained("openai/whisper-tiny.en") - processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en", language="English", task="transcribe") hf_reference_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - linear_weight = hf_reference_model.proj_out.weight linear_weight = hf_reference_model.proj_out.weight ttnn_linear_weight = ttnn.from_torch(linear_weight, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) @@ -189,50 +178,69 @@ def run_demo_functional_whisper_for_conditional_generation_inference(input_path, device=device, ) - if len(input_data) < num_inputs: - assert False, "num_inputs exceeds number of audio files available in folder" - output_list = {} - for i in range(num_inputs): + if len(input_data) < batch_size: + assert False, "batch_size exceeds number of audio files available in folder" + + for i in range(batch_size): input_file_path = input_data[i] samplerate, data = wavfile.read(input_file_path) inputs = feature_extractor(data, sampling_rate=samplerate, return_tensors="pt") dtype_to_use = torch.bfloat16 input_features = inputs.input_features.type(dtype_to_use) + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + batched_decoder_input_ids = ( + decoder_input_ids if i == 0 else torch.cat((batched_decoder_input_ids, decoder_input_ids), dim=0) + ) - attention_mask = None + profiler.start(f"preprocessing_inputs") + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=batched_inputs, + input_ids=batched_decoder_input_ids, + attention_mask=None, + parameters=parameters, + device=device, + ) + 
profiler.end(f"preprocessing_inputs") - (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=parameters, - device=device, - ) + generation_config = hf_reference_model.generation_config + ttnn_output = run_generate( + config, + input_embeds, + batched_inputs, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + ttnn_linear_weight=ttnn_linear_weight, + device=device, + decoder_input_ids=batched_decoder_input_ids, + generation_config=generation_config, + batch_size=batch_size, + max_tokens=max_tokens, + ) - generation_config = hf_reference_model.generation_config - ttnn_output = run_generate( - config, - input_embeds, - input_features, - ttnn_model, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=parameters, - processor=processor, - ttnn_linear_weight=ttnn_linear_weight, - device=device, - generation_config=generation_config, - ) - logger.info("Model Output") - logger.info(ttnn_output) - output_list[i] = ttnn_output - for i in range(len(output_list)): - logger.info(f"output for input {i+1}") - logger.info(output_list[i]) + profiler.start(f"post_processing_output_to_string") + ttnn_transcription = processor.batch_decode(ttnn_output, skip_special_tokens=True) + profiler.end(f"post_processing_output_to_string") + + logger.info("Model Output") + logger.info(ttnn_transcription) + + measurements = { + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + + return measurements, ttnn_transcription def run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, device): @@ -287,16 +295,13 @@ def run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, dev logger.info(predicted_label) -def run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, device): - torch.manual_seed(0) - +def run_demo_functional_whisper_for_conditional_generation_dataset( + ttnn_model, device, reset_seeds, batch_size=1, n_iterations=1, max_tokens=32 +): model = WhisperModel.from_pretrained("openai/whisper-tiny.en").to(torch.bfloat16).eval() - config = WhisperConfig.from_pretrained("openai/whisper-tiny.en") - processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en", language="English", task="transcribe") hf_reference_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - linear_weight = hf_reference_model.proj_out.weight linear_weight = hf_reference_model.proj_out.weight ttnn_linear_weight = ttnn.from_torch(linear_weight, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) @@ -305,47 +310,64 @@ def run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, d feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny.en") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") - dtype_to_use = torch.bfloat16 - input_features = 
inputs.input_features.type(dtype_to_use) - - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + sample = iter(ds) + batched_ground_truth_transcriptions = [] + + for _ in range(n_iterations): + for i in range(batch_size): + s = next(sample) + inputs = feature_extractor(s["audio"]["array"], sampling_rate=16000, return_tensors="pt") + ground_truth_transcriptions = s["text"] + dtype_to_use = torch.bfloat16 + input_features = inputs.input_features.type(dtype_to_use) + + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) + + decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id + decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + batched_decoder_input_ids = ( + decoder_input_ids if i == 0 else torch.cat((batched_decoder_input_ids, decoder_input_ids), dim=0) + ) + + batched_ground_truth_transcriptions.append(ground_truth_transcriptions) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.eval(), + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=device, + ) - attention_mask = None + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=batched_inputs, + input_ids=batched_decoder_input_ids, + attention_mask=None, + parameters=parameters, + device=device, + ) - parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, - ) + ttnn_output = run_generate( + config, + input_embeds, + batched_inputs, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + ttnn_linear_weight=ttnn_linear_weight, + device=device, + decoder_input_ids=batched_decoder_input_ids, + generation_config=hf_reference_model.generation_config, + batch_size=batch_size, + max_tokens=max_tokens, + ) + ttnn_transcription = processor.batch_decode(ttnn_output, skip_special_tokens=True) - (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=parameters, - device=device, - ) + logger.info("Model Output") + logger.info(ttnn_transcription) - generation_config = hf_reference_model.generation_config - ttnn_output = run_generate( - config, - input_embeds, - input_features, - ttnn_model, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=parameters, - processor=processor, - ttnn_linear_weight=ttnn_linear_weight, - device=device, - generation_config=generation_config, - ) - logger.info("Model Output") - logger.info(ttnn_output) + return ttnn_transcription @pytest.mark.parametrize( @@ -364,16 +386,20 @@ def test_demo_for_audio_classification(input_path, ttnn_model, device, num_input @pytest.mark.parametrize( "ttnn_model", - (ttnn_optimized_functional_whisper, ttnn_functional_whisper), + (ttnn_functional_whisper, ttnn_optimized_functional_whisper), ) @pytest.mark.parametrize( - "num_inputs", - ((1),), + ("batch_size", "max_tokens"), + ((8, 32),), ) -def test_demo_for_conditional_generation(input_path, ttnn_model, device, num_inputs): +def test_demo_for_conditional_generation( + 
input_path, ttnn_model, device, use_program_cache, reset_seeds, batch_size, max_tokens +): disable_persistent_kernel_cache() disable_compilation_reports() - return run_demo_functional_whisper_for_conditional_generation_inference(input_path, ttnn_model, device, num_inputs) + return run_demo_functional_whisper_for_conditional_generation_inference( + input_path, ttnn_model, device, reset_seeds, batch_size, max_tokens + ) @pytest.mark.parametrize( @@ -388,9 +414,20 @@ def test_demo_for_audio_classification_dataset(ttnn_model, device): @pytest.mark.parametrize( "ttnn_model", - (ttnn_functional_whisper, ttnn_optimized_functional_whisper), + ( + ttnn_functional_whisper, + ttnn_optimized_functional_whisper, + ), ) -def test_demo_for_conditional_generation_dataset(ttnn_model, device): +@pytest.mark.parametrize( + ("batch_size", "n_iterations", "max_tokens"), + ((8, 1, 32),), +) +def test_demo_for_conditional_generation_dataset( + ttnn_model, device, use_program_cache, reset_seeds, batch_size, n_iterations, max_tokens +): disable_persistent_kernel_cache() disable_compilation_reports() - return run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, device) + return run_demo_functional_whisper_for_conditional_generation_dataset( + ttnn_model, device, reset_seeds, batch_size, n_iterations, max_tokens + ) diff --git a/models/experimental/functional_whisper/reference/torch_baseline_whisper.py b/models/demos/whisper/reference/torch_baseline_whisper.py similarity index 100% rename from models/experimental/functional_whisper/reference/torch_baseline_whisper.py rename to models/demos/whisper/reference/torch_baseline_whisper.py diff --git a/models/experimental/functional_whisper/reference/torch_functional_whisper.py b/models/demos/whisper/reference/torch_functional_whisper.py similarity index 100% rename from models/experimental/functional_whisper/reference/torch_functional_whisper.py rename to models/demos/whisper/reference/torch_functional_whisper.py diff --git a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py b/models/demos/whisper/tt/ttnn_functional_whisper.py similarity index 95% rename from models/experimental/functional_whisper/tt/ttnn_functional_whisper.py rename to models/demos/whisper/tt/ttnn_functional_whisper.py index 30b5fa712cf..8f1f2cb837d 100644 --- a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py +++ b/models/demos/whisper/tt/ttnn_functional_whisper.py @@ -198,8 +198,11 @@ def encoder_layer(config, hidden_states, *, parameters): return hidden_states -def encoder(config, inputs_embeds, *, parameters): - hidden_states = inputs_embeds + parameters.embed_positions.weight +def encoder(config, inputs_embeds, *, parameters, device): + weights = ttnn.to_torch(parameters.embed_positions.weight) + inputs_embeds = ttnn.to_torch(inputs_embeds) + hidden_states = torch.add(inputs_embeds, weights) + hidden_states = ttnn.from_torch(hidden_states, device=device, layout=ttnn.TILE_LAYOUT) hidden_states = dropout(hidden_states, p=0, training=False) for encoder_layer_parameter in parameters.layers: @@ -399,8 +402,8 @@ def preprocess_inputs( return input_embeds, decoder_hidden_states, attention_mask -def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): - encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) +def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters, device): + encoder_hidden_states = encoder(config, 
encoder_hidden_states, parameters=parameters.encoder, device=device) last_hidden_state = decoder( config, decoder_hidden_states, @@ -411,6 +414,25 @@ def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attent return last_hidden_state +def whisper_for_conditional_generation( + config, input_embeds, decoder_hidden_states, decoder_attention_mask, *, parameters, device, ttnn_linear_weight +): + output = whisper( + config, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + device=device, + ) + ttnn_output = ttnn.matmul( + output, + ttnn_linear_weight, + dtype=ttnn.bfloat16, + ) + return ttnn_output + + def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): diff --git a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py b/models/demos/whisper/tt/ttnn_optimized_functional_whisper.py similarity index 95% rename from models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py rename to models/demos/whisper/tt/ttnn_optimized_functional_whisper.py index 56b8c0054f6..326777ef154 100644 --- a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py +++ b/models/demos/whisper/tt/ttnn_optimized_functional_whisper.py @@ -195,8 +195,11 @@ def encoder_layer(config, hidden_states, *, parameters): return hidden_states -def encoder(config, inputs_embeds, *, parameters): - hidden_states = inputs_embeds + parameters.embed_positions.weight +def encoder(config, inputs_embeds, *, parameters, device): + weights = ttnn.to_torch(parameters.embed_positions.weight) + inputs_embeds = ttnn.to_torch(inputs_embeds) + hidden_states = torch.add(inputs_embeds, weights) + hidden_states = ttnn.from_torch(hidden_states, device=device, layout=ttnn.TILE_LAYOUT) hidden_states = dropout(hidden_states, p=0, training=False) for encoder_layer_parameter in parameters.layers: @@ -396,8 +399,8 @@ def preprocess_inputs( return input_embeds, decoder_hidden_states, attention_mask -def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): - encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) +def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters, device): + encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder, device=device) last_hidden_state = decoder( config, decoder_hidden_states, @@ -408,6 +411,25 @@ def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attent return last_hidden_state +def whisper_for_conditional_generation( + config, input_embeds, decoder_hidden_states, decoder_attention_mask, *, parameters, device, ttnn_linear_weight +): + output = whisper( + config, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + device=device, + ) + ttnn_output = ttnn.matmul( + output, + ttnn_linear_weight, + dtype=ttnn.bfloat16, + ) + return ttnn_output + + def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): diff --git a/models/experimental/functional_whisper/README.md b/models/experimental/functional_whisper/README.md deleted file mode 100644 index 8a228b35e7e..00000000000 --- a/models/experimental/functional_whisper/README.md +++ /dev/null 
@@ -1,27 +0,0 @@ -# ttnn_functional_whisper Demo - -## How to Run - -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]` to run the ttnn optimized functional whisper demo for audio classification. - -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper]` to run the ttnn functional whisper demo for audio classification. - -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]` to run the ttnn optimized functional whisper demo for conditional generation. - -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper]` to run the ttnn functional whisper demo for conditional generation. - -Our another demo is designed to run with `google/fleurs` for Audio classification and `hf-internal-testing/librispeech_asr_dummy` for Conditional generation - -Use `pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification_dataset` to run audio classification demo with dataset input. - -Use `pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation_dataset` to run conditional generation demo with dataset input. - -## Inputs - -Inputs by default are provided from `dataset/audio_classification` and `dataset/conditional_generation` folder. If you wish to change the inputs, provide a different path to demo. - -For demo with dataset,Inputs for Audio classification is taken from `google/fleurs` dataset and Inputs for Conditional generation is taken from `hf-internal-testing/librispeech_asr_dummy` dataset. - -## Details - -The entry point to whisper model is whisper in `models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py` for optimized version.(`models/experimental/functional_whisper/tt/ttnn_functional_whisper.py` for normal version). 
diff --git a/tests/ttnn/integration_tests/whisper/test_performance.py b/tests/ttnn/integration_tests/whisper/test_performance.py index b88669f43d9..03210ce37f1 100644 --- a/tests/ttnn/integration_tests/whisper/test_performance.py +++ b/tests/ttnn/integration_tests/whisper/test_performance.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from models.experimental.functional_whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper +from models.demos.whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig from datasets import load_dataset import torch @@ -75,6 +75,7 @@ def test_performance(device, use_program_cache, model_name, batch_size, sequence decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=parameters, + device=device, ) tt_output = ttnn.to_torch(tt_output) end = time.time() diff --git a/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py index 30c91e7f8fe..36e0479a3ae 100644 --- a/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from models.experimental.functional_whisper.reference import torch_functional_whisper +from models.demos.whisper.reference import torch_functional_whisper import transformers from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig from datasets import load_dataset diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py index 4281672c5fd..5557d0873ff 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py @@ -3,8 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from models.experimental.functional_whisper.reference import torch_functional_whisper -from models.experimental.functional_whisper.tt import ttnn_functional_whisper +from models.demos.whisper.reference import torch_functional_whisper +from models.demos.whisper.tt import ttnn_functional_whisper import transformers from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig from datasets import load_dataset @@ -167,7 +167,7 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque device=device, ) - output = ttnn_model.encoder(config, ttnn_inputs_embeds, parameters=ttnn_parameters) + output = ttnn_model.encoder(config, ttnn_inputs_embeds, parameters=ttnn_parameters, device=device) output = ttnn.from_device(output) output = ttnn.to_torch(output) @@ -370,6 +370,7 @@ def test_ttnn_whisper(device, ttnn_model): decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=ttnn_parameters, + device=device, ) last_hidden_state = ttnn.from_device(last_hidden_state) last_hidden_state = ttnn.to_torch(last_hidden_state) diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py index da793c1b4d7..a00556b06b8 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py @@ -3,8 +3,8 @@ # 
SPDX-License-Identifier: Apache-2.0 import pytest -from models.experimental.functional_whisper.reference import torch_functional_whisper -from models.experimental.functional_whisper.tt import ttnn_optimized_functional_whisper +from models.demos.whisper.reference import torch_functional_whisper +from models.demos.whisper.tt import ttnn_optimized_functional_whisper import transformers from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig from datasets import load_dataset @@ -166,7 +166,7 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque input_embeds = ttnn.to_layout(input_embeds, ttnn.TILE_LAYOUT) input_embeds = ttnn.to_device(input_embeds, device) - output = ttnn_model.encoder(config, input_embeds, parameters=ttnn_parameters) + output = ttnn_model.encoder(config, input_embeds, parameters=ttnn_parameters, device=device) output = ttnn.to_torch(output) assert_with_pcc(torch_output, output, 0.968)