diff --git a/models/experimental/functional_whisper/README.md b/models/experimental/functional_whisper/README.md index 8a228b35e7e..73ee452b1ff 100644 --- a/models/experimental/functional_whisper/README.md +++ b/models/experimental/functional_whisper/README.md @@ -1,20 +1,17 @@ -# ttnn_functional_whisper Demo +--- -## How to Run +# Functional Whisper Model Demos For Audio Classification and Text Generation -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]` to run the ttnn optimized functional whisper demo for audio classification. +## Introduction -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper]` to run the ttnn functional whisper demo for audio classification. +Whisper is a pre-trained model for automatic speech recognition (ASR) and speech translation. The models are trained on either English-only data or multilingual data. The English-only models were trained on the task of speech recognition. The multilingual models were trained on both speech recognition and speech translation tasks. -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]` to run the ttnn optimized functional whisper demo for conditional generation. +The demos showcase the Functional Whisper model for the Audio Classification and Text Generation tasks; +the `sanchit-gandhi/whisper-medium-fleurs-lang-id` and `openai/whisper-tiny.en` Hugging Face checkpoints are used for the respective tasks. -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper]` to run the ttnn functional whisper demo for conditional generation. +### Details -Our another demo is designed to run with `google/fleurs` for Audio classification and `hf-internal-testing/librispeech_asr_dummy` for Conditional generation - -Use `pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification_dataset` to run audio classification demo with dataset input. - -Use `pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation_dataset` to run conditional generation demo with dataset input. +The entry point to the Functional Whisper model is the `whisper` function located in `ttnn_optimized_functional_whisper.py`. ## Inputs @@ -22,6 +19,55 @@ Inputs by default are provided from `dataset/audio_classification` and `dataset/ For demo with dataset,Inputs for Audio classification is taken from `google/fleurs` dataset and Inputs for Conditional generation is taken from `hf-internal-testing/librispeech_asr_dummy` dataset.
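For orientation, here is a minimal sketch of how the audio-classification entry point described above can be driven outside of pytest. It strings together the same calls the demo makes; the `classify_language` helper and its `audio_file`/`device` arguments are hypothetical (the caller must open the ttnn device), and the `preprocess_model_parameters` import path is assumed from how the repo's other demos use it. Treat this as an illustration, not the packaged demo.

```
# Minimal, illustrative sketch (assumptions noted above): classify the language of one WAV clip
# with the TT-NN Whisper audio-classification head, mirroring the calls made by demo.py.
import torch
import ttnn
from scipy.io import wavfile
from transformers import AutoFeatureExtractor, WhisperForAudioClassification
from ttnn.model_preprocessing import preprocess_model_parameters  # import path assumed
from models.experimental.functional_whisper.tt import ttnn_optimized_functional_whisper as ttnn_model


def classify_language(audio_file, device):
    feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
    model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
    model.eval()

    # Convert the torch weights into ttnn parameters once, up front.
    parameters = preprocess_model_parameters(
        initialize_model=lambda: model,
        convert_to_ttnn=ttnn_model.convert_to_ttnn,
        custom_preprocessor=ttnn_model.custom_preprocessor,
        device=device,
    )

    # Log-mel features for a single clip (batch_size=1 here; the demo batches 8 clips along dim 0).
    sampling_rate, data = wavfile.read(audio_file)
    inputs = feature_extractor(data, sampling_rate=sampling_rate, return_tensors="pt")

    # Preprocess the encoder inputs and run the encoder plus classification head on device.
    input_embedding = ttnn_model.preprocess_encoder_inputs(
        input_features=inputs.input_features, parameters=parameters.encoder, device=device
    )
    logits = ttnn_model.whisper_for_audio_classification(
        config=model.config,
        inputs_embeds=input_embedding,
        parameters=parameters,
        device=device,
        batch_size=1,
    )

    # Pick the most likely class id and map it back to a language label.
    logits_torch = ttnn.to_torch(logits)[0].squeeze(0)
    return model.config.id2label[torch.argmax(logits_torch).item()]
```

The batched demo below does essentially the same thing, except that it concatenates eight sets of `input_features` along dim 0 and passes `batch_size=8`.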
-## Details +## Batch size: 8 + +Batch size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It is recommended to set the `batch_size` to 8. + +## How to run demo for Audio Classification task + +To run the demo for audio classification using the Whisper model, follow these instructions: + +- Use the following command to run the whisper for audio classification demo with ttnn optimized functional whisper: + ``` + pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[8-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper] + ``` + +- Use the following command to run the whisper for audio classification demo with ttnn functional whisper: + ``` + pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[8-models.experimental.functional_whisper.tt.ttnn_functional_whisper] + ``` + +- Another demo is designed to run with the `google/fleurs` dataset for audio classification; to run the dataset demo, use the command: + ``` + pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification_dataset + ``` + +## How to run demo for Text Generation task + +To run the demo for text generation using the Whisper model, follow these instructions: + +- Use the following command to run the whisper for text generation demo with ttnn optimized functional whisper: + ``` + pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper] + ``` + +- Use the following command to run the whisper for text generation demo with ttnn functional whisper: + ``` + pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper] + ``` + +- Another demo is designed to run with the `hf-internal-testing/librispeech_asr_dummy` dataset for text generation; to run the dataset demo, use the command: + ``` + pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation_dataset + ``` + +## Results + +The demos present a comprehensive view of the Whisper model's robustness on audio classification and text generation tasks. + +Audio classification predicts the language of the provided audio sample, and the dataset demo +also reports the accuracy of the model; +for example, `batch_size=8` and `n_iterations=3` gives an accuracy of 0.75. + +For Text generation, the model predicts transcriptions in the same language as the audio (English). -The entry point to whisper model is whisper in `models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py` for optimized version.(`models/experimental/functional_whisper/tt/ttnn_functional_whisper.py` for normal version). 
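The accuracy figure quoted in the Results section is computed in the dataset demo with `sklearn.metrics.accuracy_score` over language labels, as the demo.py changes further down show. A toy illustration with hypothetical labels:

```
# Toy illustration (made-up labels) of how the dataset demo scores itself:
# FLEURS reference languages are compared against the labels decoded from the model's logits.
from sklearn.metrics import accuracy_score

reference_labels = ["English", "French", "Bengali", "Estonian"]    # hypothetical ground truth from the dataset
predicted_labels = ["English", "French", "Bengali", "Indonesian"]  # hypothetical predictions decoded from logits

print(f"Accuracy: {accuracy_score(reference_labels, predicted_labels)}")  # 0.75 for this made-up batch
```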
+--- diff --git a/models/experimental/functional_whisper/demo/demo.py b/models/experimental/functional_whisper/demo/demo.py index 79d364e070d..fe850f51b28 100644 --- a/models/experimental/functional_whisper/demo/demo.py +++ b/models/experimental/functional_whisper/demo/demo.py @@ -30,6 +30,7 @@ from transformers import AutoFeatureExtractor, WhisperForAudioClassification from datasets import load_dataset +from sklearn.metrics import accuracy_score def load_input_paths(folder_path): @@ -109,9 +110,9 @@ def run_generate( return ttnn_transcription -def run_demo_functional_whisper_for_audio_classification_inference(input_path, ttnn_model, device, num_inputs): - torch.manual_seed(1234) - +def run_demo_functional_whisper_for_audio_classification_inference( + reset_seeds, input_path, ttnn_model, device, batch_size +): feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") @@ -124,10 +125,11 @@ def run_demo_functional_whisper_for_audio_classification_inference(input_path, t custom_preprocessor=ttnn_model.custom_preprocessor, device=device, ) - if len(input_data) < num_inputs: - assert False, "num_inputs exceeds number of audio files available in folder" + if len(input_data) < batch_size: + assert False, "batch_size exceeds number of audio files available in folder" - for i in range(num_inputs): + batched_inputs = [] + for i in range(batch_size): input_file_path = input_data[i] samplerate, data = wavfile.read(input_file_path) @@ -138,30 +140,33 @@ def run_demo_functional_whisper_for_audio_classification_inference(input_path, t ) input_features = inputs.input_features + if i == 0: + batched_inputs = input_features + else: + batched_inputs = torch.cat((batched_inputs, input_features), dim=0) - config = model.config - input_embedding = ttnn_model.preprocess_encoder_inputs( - input_features=input_features, parameters=parameters.encoder, device=device - ) - - encoder_outputs = ttnn_model.encoder( - config=config, inputs_embeds=input_embedding, parameters=parameters.encoder - ) - - hidden_states = ttnn.matmul(encoder_outputs, parameters.projector.weight) - hidden_states = ttnn.add(hidden_states, parameters.projector.bias) - - pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True) + config = model.config + input_embedding = ttnn_model.preprocess_encoder_inputs( + input_features=batched_inputs, parameters=parameters.encoder, device=device + ) - logits = ttnn.matmul(pooled_output, parameters.classifier.weight) - logits = ttnn.add(logits, parameters.classifier.bias) + out_logits = ttnn_model.whisper_for_audio_classification( + config=config, + inputs_embeds=input_embedding, + parameters=parameters, + device=device, + batch_size=batch_size, + ) - logits_torch = ttnn.to_torch(logits) - predicted_class_ids = torch.argmax(logits_torch).item() + logits_torch = ttnn.to_torch(out_logits) + predicted_list = [] + for i in range(batch_size): + single_logits_torch = logits_torch[i].squeeze(0) + predicted_class_ids = torch.argmax(single_logits_torch).item() predicted_label = model.config.id2label[predicted_class_ids] - - logger.info("predicted_label") - logger.info(predicted_label) + logger.info(f"predicted_label: {predicted_label}") + predicted_list.append(predicted_label) + return predicted_list def run_demo_functional_whisper_for_conditional_generation_inference(input_path, ttnn_model, device, num_inputs): @@ -235,28 +240,19 @@ def 
run_demo_functional_whisper_for_conditional_generation_inference(input_path, logger.info(output_list[i]) -def run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, device): - torch.manual_seed(1234) - +def run_demo_functional_whisper_for_audio_classification_dataset( + reset_seeds, ttnn_model, device, batch_size=8, n_iterations=1 +): feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") model.eval() - ds = load_dataset("google/fleurs", "all", split="validation", streaming=True) - sample = next(iter(ds)) - - inputs = feature_extractor( - sample["audio"]["array"], - sampling_rate=sample["audio"]["sampling_rate"], - return_tensors="pt", - ) - - input_features = inputs.input_features - - logger.debug("Input audio language:") - logger.debug(sample["language"]) + ds_iter = iter(ds) + reference_labels = [] + predicted_labels = [] + config = model.config parameters = preprocess_model_parameters( initialize_model=lambda: model, convert_to_ttnn=ttnn_model.convert_to_ttnn, @@ -264,27 +260,50 @@ def run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, dev device=device, ) - config = model.config - input_embedding = ttnn_model.preprocess_encoder_inputs( - input_features=input_features, parameters=parameters.encoder, device=device - ) - - encoder_outputs = ttnn_model.encoder(config=config, inputs_embeds=input_embedding, parameters=parameters.encoder) - - hidden_states = ttnn.matmul(encoder_outputs, parameters.projector.weight) - hidden_states = ttnn.add(hidden_states, parameters.projector.bias) - - pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True) + for _ in range(n_iterations): + batch_input = [] + # prepare the batched audio inputs + for bs in range(batch_size): + sample = next(ds_iter) + inputs = feature_extractor( + sample["audio"]["array"], + sampling_rate=sample["audio"]["sampling_rate"], + return_tensors="pt", + ) + input_features = inputs.input_features + if bs == 0: + batch_input = input_features + else: + batch_input = torch.cat((batch_input, input_features), dim=0) + reference_labels.append(sample["language"]) + + # preprocess the inputs + input_embedding = ttnn_model.preprocess_encoder_inputs( + input_features=batch_input, parameters=parameters.encoder, device=device + ) - logits = ttnn.matmul(pooled_output, parameters.classifier.weight) - logits = ttnn.add(logits, parameters.classifier.bias) + # run the model + out_logits = ttnn_model.whisper_for_audio_classification( + config=config, + inputs_embeds=input_embedding, + parameters=parameters, + device=device, + batch_size=batch_size, + ) - logits_torch = ttnn.to_torch(logits) - predicted_class_ids = torch.argmax(logits_torch).item() - predicted_label = model.config.id2label[predicted_class_ids] + # postprocessing the outputs + logits_torch = ttnn.to_torch(out_logits) + for i in range(batch_size): + single_logits_torch = logits_torch[i].squeeze(0) + predicted_class_ids = torch.argmax(single_logits_torch).item() + predicted_label = model.config.id2label[predicted_class_ids] + predicted_labels.append(predicted_label) - logger.info("predicted_label") - logger.info(predicted_label) + accuracy = accuracy_score(reference_labels, predicted_labels) + logger.info(f"reference labels: {reference_labels}") + logger.info(f"predicted labels: {predicted_labels}") + logger.info(f"Accuracy: {accuracy}") + return accuracy def 
run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, device): @@ -353,13 +372,15 @@ def run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, d (ttnn_optimized_functional_whisper, ttnn_functional_whisper), ) @pytest.mark.parametrize( - "num_inputs", - ((1),), + "batch_size", + ((8),), ) -def test_demo_for_audio_classification(input_path, ttnn_model, device, num_inputs): +def test_demo_for_audio_classification(reset_seeds, input_path, ttnn_model, device, batch_size): disable_persistent_kernel_cache() disable_compilation_reports() - return run_demo_functional_whisper_for_audio_classification_inference(input_path, ttnn_model, device, num_inputs) + return run_demo_functional_whisper_for_audio_classification_inference( + reset_seeds, input_path, ttnn_model, device, batch_size + ) @pytest.mark.parametrize( @@ -380,10 +401,20 @@ def test_demo_for_conditional_generation(input_path, ttnn_model, device, num_inp "ttnn_model", (ttnn_optimized_functional_whisper, ttnn_functional_whisper), ) -def test_demo_for_audio_classification_dataset(ttnn_model, device): +@pytest.mark.parametrize( + "batch_size", + ((8),), +) +@pytest.mark.parametrize( + "n_iterations", + ((5),), +) +def test_demo_for_audio_classification_dataset(reset_seeds, ttnn_model, device, batch_size, n_iterations): disable_persistent_kernel_cache() disable_compilation_reports() - return run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, device) + return run_demo_functional_whisper_for_audio_classification_dataset( + reset_seeds, ttnn_model, device, batch_size=batch_size, n_iterations=n_iterations + ) @pytest.mark.parametrize( diff --git a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py b/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py index 30b5fa712cf..6060a507318 100644 --- a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py +++ b/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py @@ -198,8 +198,14 @@ def encoder_layer(config, hidden_states, *, parameters): return hidden_states -def encoder(config, inputs_embeds, *, parameters): - hidden_states = inputs_embeds + parameters.embed_positions.weight +def encoder(config, inputs_embeds, *, parameters, device, batch_size): + # issue #7872 + # broadcast add is not happening for batched inputs + # hidden_states = inputs_embeds + parameters.embed_positions.weight + weights = ttnn.to_torch(parameters.embed_positions.weight) + embeds = ttnn.to_torch(inputs_embeds) + hidden_states = torch.add(weights, embeds) + hidden_states = ttnn.from_torch(hidden_states, device=device, layout=ttnn.TILE_LAYOUT) hidden_states = dropout(hidden_states, p=0, training=False) for encoder_layer_parameter in parameters.layers: @@ -399,8 +405,19 @@ def preprocess_inputs( return input_embeds, decoder_hidden_states, attention_mask -def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): - encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) +def whisper( + config, + encoder_hidden_states, + decoder_hidden_states, + decoder_attention_mask, + *, + parameters, + device=None, + batch_size=1, +): + encoder_hidden_states = encoder( + config, encoder_hidden_states, parameters=parameters.encoder, device=device, batch_size=batch_size + ) last_hidden_state = decoder( config, decoder_hidden_states, @@ -411,6 +428,26 @@ def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attent return 
last_hidden_state +def whisper_for_audio_classification(config, inputs_embeds, *, parameters, device, batch_size): + encoder_outputs = encoder( + config=config, + inputs_embeds=inputs_embeds, + parameters=parameters.encoder, + device=device, + batch_size=batch_size, + ) + + hidden_states = ttnn.matmul(encoder_outputs, parameters.projector.weight) + hidden_states = ttnn.add(hidden_states, parameters.projector.bias) + + pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True) + + logits = ttnn.matmul(pooled_output, parameters.classifier.weight) + logits = ttnn.add(logits, parameters.classifier.bias) + + return logits + + def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): diff --git a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py b/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py index 56b8c0054f6..80a32bb04a4 100644 --- a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py +++ b/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py @@ -126,16 +126,14 @@ def calculate_query_key_values(config, hidden_states, *, parameters): def whisper_attention(config, hidden_states, attention_mask, key_value_states=None, *, parameters): head_size = config.d_model // config.encoder_attention_heads scaling = head_size**-0.5 - bsz, *_, tgt_len, _ = hidden_states.shape + bsz, *_, tgt_len, tgt_wid = hidden_states.shape is_cross_attention = key_value_states is not None if is_cross_attention: query_states = hidden_states @ parameters.q_proj.weight + parameters.q_proj.bias - dtype = query_states.dtype - device = query_states.device() - query_states = ttnn.to_torch(query_states) - query_states = torch.reshape(query_states, (bsz, tgt_len, config.encoder_attention_heads, head_size)) - query_states = ttnn.from_torch(query_states, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) + query_states = ttnn.to_layout(query_states, layout=ttnn.ROW_MAJOR_LAYOUT) + query_states = ttnn.reshape(query_states, (bsz, tgt_len, config.encoder_attention_heads, head_size)) + query_states = ttnn.to_layout(query_states, layout=ttnn.TILE_LAYOUT) query_states = ttnn.permute(query_states, (0, 2, 1, 3)) key_states, value_states = calculate_key_values(config, key_value_states, parameters=parameters) else: @@ -195,8 +193,14 @@ def encoder_layer(config, hidden_states, *, parameters): return hidden_states -def encoder(config, inputs_embeds, *, parameters): - hidden_states = inputs_embeds + parameters.embed_positions.weight +def encoder(config, inputs_embeds, *, parameters, device, batch_size): + # issue #7872 + # Add op is not supported for batched inputs (broadcasting not happening) + # hidden_states = inputs_embeds + parameters.embed_positions.weight + weights = ttnn.to_torch(parameters.embed_positions.weight) + embeds = ttnn.to_torch(inputs_embeds) + hidden_states = torch.add(weights, embeds) + hidden_states = ttnn.from_torch(hidden_states, device=device, layout=ttnn.TILE_LAYOUT) hidden_states = dropout(hidden_states, p=0, training=False) for encoder_layer_parameter in parameters.layers: @@ -396,8 +400,19 @@ def preprocess_inputs( return input_embeds, decoder_hidden_states, attention_mask -def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): - encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) +def whisper( + config, + 
encoder_hidden_states, + decoder_hidden_states, + decoder_attention_mask, + *, + parameters, + batch_size=None, + device=None, +): + encoder_hidden_states = encoder( + config, encoder_hidden_states, parameters=parameters.encoder, device=device, batch_size=batch_size + ) last_hidden_state = decoder( config, decoder_hidden_states, @@ -408,6 +423,26 @@ def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attent return last_hidden_state +def whisper_for_audio_classification(config, inputs_embeds, *, parameters, device, batch_size): + encoder_outputs = encoder( + config=config, + inputs_embeds=inputs_embeds, + parameters=parameters.encoder, + device=device, + batch_size=batch_size, + ) + + hidden_states = ttnn.matmul(encoder_outputs, parameters.projector.weight) + hidden_states = ttnn.add(hidden_states, parameters.projector.bias) + + pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True) + + logits = ttnn.matmul(pooled_output, parameters.classifier.weight) + logits = ttnn.add(logits, parameters.classifier.bias) + + return logits + + def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): diff --git a/tests/ttnn/integration_tests/whisper/test_demo.py b/tests/ttnn/integration_tests/whisper/test_demo.py new file mode 100644 index 00000000000..7cd17c2b65b --- /dev/null +++ b/tests/ttnn/integration_tests/whisper/test_demo.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from models.experimental.functional_whisper.demo.demo import test_demo_for_audio_classification as demo_audio_files +from models.experimental.functional_whisper.demo.demo import ( + test_demo_for_audio_classification_dataset as demo_audio_dataset, +) +import pytest +from models.experimental.functional_whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper + + +@pytest.mark.parametrize( + "input_path", + (("models/experimental/functional_whisper/demo/dataset/audio_classification"),), +) +@pytest.mark.parametrize( + "ttnn_model", + (ttnn_optimized_functional_whisper, ttnn_functional_whisper), +) +@pytest.mark.parametrize( + "batch_size", + ((8),), +) +def test_audio_demo_batch_8(device, reset_seeds, input_path, ttnn_model, batch_size): + expected_answers = { + 0: "English", + 1: "Estonian", + 2: "French", + 3: "Bengali", + 4: "Bengali", + 5: "Estonian", + 6: "English", + 7: "Indonesian", + } + predicted_labels = demo_audio_files(reset_seeds, input_path, ttnn_model, device, batch_size) + + for i in range(batch_size): + assert expected_answers[i] == predicted_labels[i] + + +@pytest.mark.parametrize( + "ttnn_model", + (ttnn_optimized_functional_whisper, ttnn_functional_whisper), +) +@pytest.mark.parametrize( + "batch_size", + ((8),), +) +@pytest.mark.parametrize( + "n_iterations", + ((5),), +) +@pytest.mark.parametrize( + "accuracy", + ((0.7),), +) +def test_audio_demo_dataset(device, reset_seeds, ttnn_model, batch_size, n_iterations, accuracy): + cal_acc = demo_audio_dataset(reset_seeds, ttnn_model, device, batch_size, n_iterations) + assert cal_acc >= accuracy diff --git a/tests/ttnn/integration_tests/whisper/test_performance.py b/tests/ttnn/integration_tests/whisper/test_performance.py index c58ca763995..9b2e9776f54 100644 --- a/tests/ttnn/integration_tests/whisper/test_performance.py +++ b/tests/ttnn/integration_tests/whisper/test_performance.py @@ -22,7 +22,6 @@ def get_expected_times(functional_whisper): 
}[functional_whisper] -@skip_for_wormhole_b0(reason_str="Not tested on single WH") @pytest.mark.models_performance_bare_metal @pytest.mark.models_performance_virtual_machine @pytest.mark.parametrize("model_name", ["openai/whisper-base"]) @@ -75,6 +74,8 @@ def test_performance(device, use_program_cache, model_name, batch_size, sequence decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=parameters, + device=device, + batch_size=1, ) tt_output = ttnn.to_torch(tt_output) end = time.time() diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py index 4281672c5fd..0e6a68d8191 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py @@ -19,7 +19,6 @@ MODEL_NAME = "openai/whisper-base" -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -84,7 +83,6 @@ def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_ assert_with_pcc(torch_output, output, 0.98) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -120,7 +118,6 @@ def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size assert_with_pcc(torch_output, output, pcc=0.99) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -167,14 +164,13 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque device=device, ) - output = ttnn_model.encoder(config, ttnn_inputs_embeds, parameters=ttnn_parameters) + output = ttnn_model.encoder(config, ttnn_inputs_embeds, parameters=ttnn_parameters, device=device, batch_size=1) output = ttnn.from_device(output) output = ttnn.to_torch(output) assert_with_pcc(torch_output, output, 0.97) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -233,7 +229,6 @@ def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size assert_with_pcc(torch_output, output, 0.97) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -305,7 +300,6 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): assert_with_pcc(torch_output, output, pcc=0.99) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) def test_ttnn_whisper(device, ttnn_model): torch.manual_seed(0) @@ -370,8 +364,10 @@ def test_ttnn_whisper(device, ttnn_model): decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=ttnn_parameters, + device=device, + batch_size=1, ) last_hidden_state = ttnn.from_device(last_hidden_state) last_hidden_state = ttnn.to_torch(last_hidden_state) - assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.9895) + assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.987) diff --git 
a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py index 838b8dcbd79..c6188a7f1e4 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py @@ -18,7 +18,6 @@ MODEL_NAME = "openai/whisper-base" -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -82,7 +81,6 @@ def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_ assert_with_pcc(torch_output, output, 0.98) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -119,7 +117,6 @@ def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size assert_with_pcc(torch_output, output, pcc=0.99) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -166,13 +163,12 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque input_embeds = ttnn.to_layout(input_embeds, ttnn.TILE_LAYOUT) input_embeds = ttnn.to_device(input_embeds, device) - output = ttnn_model.encoder(config, input_embeds, parameters=ttnn_parameters) + output = ttnn_model.encoder(config, input_embeds, parameters=ttnn_parameters, device=device, batch_size=1) output = ttnn.to_torch(output) assert_with_pcc(torch_output, output, 0.968) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -229,7 +225,6 @@ def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size assert_with_pcc(torch_output, output, 0.97) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -301,7 +296,6 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): assert_with_pcc(torch_output, output, pcc=0.99) -@skip_for_wormhole_b0() @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) def test_ttnn_whisper(tmp_path, device, ttnn_model): torch.manual_seed(0) @@ -361,8 +355,10 @@ def test_ttnn_whisper(tmp_path, device, ttnn_model): decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=ttnn_parameters, + device=device, + batch_size=1, ) last_hidden_state = ttnn.to_torch(last_hidden_state) ttnn.tracer.visualize(last_hidden_state, file_name=tmp_path / "whisper.svg") - assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.964) + assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.96)