From 0948b0429231711adbf9d909539701a0b7cc0823 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Wed, 8 May 2024 13:27:03 +0000 Subject: [PATCH] #8246: functional_whisper CG test_demo --- .../functional_whisper/demo/demo.py | 314 ++++++++++-------- .../tt/ttnn_functional_whisper.py | 30 +- .../tt/ttnn_optimized_functional_whisper.py | 30 +- .../integration_tests/whisper/test_demo.py | 67 ++++ .../whisper/test_ttnn_functional_whisper.py | 3 +- .../test_ttnn_optimized_functional_whisper.py | 3 +- 6 files changed, 304 insertions(+), 143 deletions(-) create mode 100644 tests/ttnn/integration_tests/whisper/test_demo.py diff --git a/models/experimental/functional_whisper/demo/demo.py b/models/experimental/functional_whisper/demo/demo.py index 79d364e070d..f0a26e3d954 100644 --- a/models/experimental/functional_whisper/demo/demo.py +++ b/models/experimental/functional_whisper/demo/demo.py @@ -18,9 +18,10 @@ from models.utility_functions import ( disable_compilation_reports, disable_persistent_kernel_cache, + profiler, ) from models.experimental.functional_whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper -from models.generation_utils import get_logits_processor +from models.generation_utils import get_logits_processor, pad_input_32 from ttnn.model_preprocessing import preprocess_model_parameters import torch @@ -30,6 +31,7 @@ from transformers import AutoFeatureExtractor, WhisperForAudioClassification from datasets import load_dataset +from torchmetrics.text import WordErrorRate def load_input_paths(folder_path): @@ -37,20 +39,6 @@ def load_input_paths(folder_path): return files -def pad_input_32(tensor, value): - len = tensor.shape[1] - - if len % 32 == 0: - return tensor - - padded_len = ((len // 32) + 1) * 32 - - pad_tensor = (value * torch.ones(tensor.shape[0], padded_len - len)).to(torch.long) - tensor = torch.cat([tensor, pad_tensor], dim=1) - - return tensor - - def run_generate( config, input_embeds, @@ -59,54 +47,59 @@ def run_generate( decoder_hidden_states, decoder_attention_mask, parameters, - processor, ttnn_linear_weight, device, + decoder_input_ids, generation_config, + batch_size, + max_tokens, ): - input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - - logits_processor = get_logits_processor(input_ids, config) - - input_ids = pad_input_32(input_ids, config.pad_token_id).to(torch.long) - - decoder_start_values = generation_config.pad_token_id * torch.ones(1, 32).to(torch.long) + logits_processor = get_logits_processor(decoder_input_ids, config) + decoder_start_values = generation_config.pad_token_id * torch.ones(batch_size, input_features.shape[1]).to( + torch.long + ) + eos_reached = torch.zeros(batch_size, dtype=torch.bool) - for i in range(32): - output = ttnn_model.whisper( - config, - input_embeds, - decoder_hidden_states, + profiler.start(f"inference_time") + for i in range(max_tokens): + ttnn_output = ttnn_model.whisper_for_conditional_generation( + config=config, + input_embeds=input_embeds, + decoder_hidden_states=decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=parameters, + device=device, + ttnn_linear_weight=ttnn_linear_weight, ) - output = output @ ttnn_linear_weight - - output = ttnn.from_device(output) - - logits_to_torch = ttnn.to_torch(output) - + ttnn_output = ttnn.from_device(ttnn_output) + logits_to_torch = ttnn.to_torch(ttnn_output) next_token_logits = logits_to_torch[:, i, :] - next_tokens_scores = logits_processor(input_features, next_token_logits) - next_tokens = torch.argmax(next_tokens_scores, dim=-1) + next_tokens = torch.argmax(next_tokens_scores, dim=-1).unsqueeze(0) - if (i + 1) % 32 == 0: - input_ids = torch.cat([input_ids, decoder_start_values], dim=1) + # Check if EOS token is generated for any sample in the batch and + # Setting subsequent next_tokens to config.pad_token_id if EOS token is reached. + eos_generated_flags = next_tokens == config.eos_token_id + eos_reached = eos_reached | eos_generated_flags.squeeze(0) + next_tokens[:, eos_reached] = config.pad_token_id - input_ids[:, i + 1] = next_tokens[:, None] + if (i + 1) % 32 == 0: + decoder_input_ids = torch.cat([decoder_input_ids, decoder_start_values], dim=1) + decoder_input_ids[:, i + 1] = next_tokens[:, None] decoder_hidden_states, decoder_attention_mask = ttnn_model.preprocess_decoder_inputs( - config=config, input_ids=input_ids, attention_mask=None, parameters=parameters.decoder, device=device + config=config, + input_ids=decoder_input_ids, + attention_mask=None, + parameters=parameters.decoder, + device=device, ) - if next_tokens == config.eos_token_id: + if torch.all(next_tokens == config.eos_token_id): break - logger.info(processor.batch_decode(input_ids, skip_special_tokens=True)[0]) - ttnn_transcription = processor.batch_decode(input_ids, skip_special_tokens=True)[0] - - return ttnn_transcription + profiler.end(f"inference_time") + return decoder_input_ids def run_demo_functional_whisper_for_audio_classification_inference(input_path, ttnn_model, device, num_inputs): @@ -164,16 +157,13 @@ def run_demo_functional_whisper_for_audio_classification_inference(input_path, t logger.info(predicted_label) -def run_demo_functional_whisper_for_conditional_generation_inference(input_path, ttnn_model, device, num_inputs): - torch.manual_seed(0) - +def run_demo_functional_whisper_for_conditional_generation_inference( + input_path, ttnn_model, device, reset_seeds, batch_size=1, max_tokens=32 +): model = WhisperModel.from_pretrained("openai/whisper-tiny.en").to(torch.bfloat16).eval() - config = WhisperConfig.from_pretrained("openai/whisper-tiny.en") - processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en", language="English", task="transcribe") hf_reference_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - linear_weight = hf_reference_model.proj_out.weight linear_weight = hf_reference_model.proj_out.weight ttnn_linear_weight = ttnn.from_torch(linear_weight, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) @@ -189,50 +179,69 @@ def run_demo_functional_whisper_for_conditional_generation_inference(input_path, device=device, ) - if len(input_data) < num_inputs: - assert False, "num_inputs exceeds number of audio files available in folder" - output_list = {} - for i in range(num_inputs): + if len(input_data) < batch_size: + assert False, "batch_size exceeds number of audio files available in folder" + + for i in range(batch_size): input_file_path = input_data[i] samplerate, data = wavfile.read(input_file_path) inputs = feature_extractor(data, sampling_rate=samplerate, return_tensors="pt") dtype_to_use = torch.bfloat16 input_features = inputs.input_features.type(dtype_to_use) + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + batched_decoder_input_ids = ( + decoder_input_ids if i == 0 else torch.cat((batched_decoder_input_ids, decoder_input_ids), dim=0) + ) - attention_mask = None + profiler.start(f"preprocessing_inputs") + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=batched_inputs, + input_ids=batched_decoder_input_ids, + attention_mask=None, + parameters=parameters, + device=device, + ) + profiler.end(f"preprocessing_inputs") - (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=parameters, - device=device, - ) + generation_config = hf_reference_model.generation_config + ttnn_output = run_generate( + config, + input_embeds, + batched_inputs, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + ttnn_linear_weight=ttnn_linear_weight, + device=device, + decoder_input_ids=batched_decoder_input_ids, + generation_config=generation_config, + batch_size=batch_size, + max_tokens=max_tokens, + ) - generation_config = hf_reference_model.generation_config - ttnn_output = run_generate( - config, - input_embeds, - input_features, - ttnn_model, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=parameters, - processor=processor, - ttnn_linear_weight=ttnn_linear_weight, - device=device, - generation_config=generation_config, - ) - logger.info("Model Output") - logger.info(ttnn_output) - output_list[i] = ttnn_output - for i in range(len(output_list)): - logger.info(f"output for input {i+1}") - logger.info(output_list[i]) + profiler.start(f"post_processing_output_to_string") + ttnn_transcription = processor.batch_decode(ttnn_output, skip_special_tokens=True) + profiler.end(f"post_processing_output_to_string") + + logger.info("Model Output") + logger.info(ttnn_transcription) + + measurements = { + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + + return measurements, ttnn_transcription def run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, device): @@ -287,16 +296,13 @@ def run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, dev logger.info(predicted_label) -def run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, device): - torch.manual_seed(0) - +def run_demo_functional_whisper_for_conditional_generation_dataset( + ttnn_model, device, reset_seeds, batch_size=1, n_iterations=1, max_tokens=32 +): model = WhisperModel.from_pretrained("openai/whisper-tiny.en").to(torch.bfloat16).eval() - config = WhisperConfig.from_pretrained("openai/whisper-tiny.en") - processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en", language="English", task="transcribe") hf_reference_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - linear_weight = hf_reference_model.proj_out.weight linear_weight = hf_reference_model.proj_out.weight ttnn_linear_weight = ttnn.from_torch(linear_weight, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) @@ -305,47 +311,77 @@ def run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, d feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny.en") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") - dtype_to_use = torch.bfloat16 - input_features = inputs.input_features.type(dtype_to_use) + sample = iter(ds) + batched_ground_truth_transcriptions = [] + + for _ in range(n_iterations): + for i in range(batch_size): + s = next(sample) + inputs = feature_extractor(s["audio"]["array"], sampling_rate=16000, return_tensors="pt") + ground_truth_transcriptions = s["text"] + dtype_to_use = torch.bfloat16 + input_features = inputs.input_features.type(dtype_to_use) + + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) + + decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id + decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + batched_decoder_input_ids = ( + decoder_input_ids if i == 0 else torch.cat((batched_decoder_input_ids, decoder_input_ids), dim=0) + ) + + batched_ground_truth_transcriptions.append(ground_truth_transcriptions) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.eval(), + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=device, + ) - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=batched_inputs, + input_ids=batched_decoder_input_ids, + attention_mask=None, + parameters=parameters, + device=device, + ) - attention_mask = None + ttnn_output = run_generate( + config, + input_embeds, + batched_inputs, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + ttnn_linear_weight=ttnn_linear_weight, + device=device, + decoder_input_ids=batched_decoder_input_ids, + generation_config=hf_reference_model.generation_config, + batch_size=batch_size, + max_tokens=max_tokens, + ) + ttnn_transcription = processor.batch_decode(ttnn_output, skip_special_tokens=True) - parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, - ) + logger.info("Model Output") + logger.info(ttnn_transcription) - (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=parameters, - device=device, - ) + wer = WordErrorRate() + wer_scores = [] + for transcription, ground_truth in zip(ttnn_transcription, batched_ground_truth_transcriptions): + transcription = transcription.upper() + individual_wer_score = wer([transcription], [ground_truth]) + wer_scores.append(individual_wer_score) + logger.info(f"Individual Sample WER score: {individual_wer_score}") - generation_config = hf_reference_model.generation_config - ttnn_output = run_generate( - config, - input_embeds, - input_features, - ttnn_model, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=parameters, - processor=processor, - ttnn_linear_weight=ttnn_linear_weight, - device=device, - generation_config=generation_config, - ) - logger.info("Model Output") - logger.info(ttnn_output) + average_wer_score = sum(wer_scores) / len(wer_scores) + logger.info(f"Average WER score: {average_wer_score}") + accuracy = 1 - average_wer_score + logger.info(f"Accuracy: {accuracy}") + + return average_wer_score @pytest.mark.parametrize( @@ -367,13 +403,17 @@ def test_demo_for_audio_classification(input_path, ttnn_model, device, num_input (ttnn_optimized_functional_whisper, ttnn_functional_whisper), ) @pytest.mark.parametrize( - "num_inputs", - ((1),), + ("batch_size", "max_tokens"), + ((8, 32),), ) -def test_demo_for_conditional_generation(input_path, ttnn_model, device, num_inputs): +def test_demo_for_conditional_generation( + input_path, ttnn_model, device, use_program_cache, reset_seeds, batch_size, max_tokens +): disable_persistent_kernel_cache() disable_compilation_reports() - return run_demo_functional_whisper_for_conditional_generation_inference(input_path, ttnn_model, device, num_inputs) + return run_demo_functional_whisper_for_conditional_generation_inference( + input_path, ttnn_model, device, reset_seeds, batch_size, max_tokens + ) @pytest.mark.parametrize( @@ -388,9 +428,17 @@ def test_demo_for_audio_classification_dataset(ttnn_model, device): @pytest.mark.parametrize( "ttnn_model", - (ttnn_functional_whisper, ttnn_optimized_functional_whisper), + (ttnn_optimized_functional_whisper, ttnn_functional_whisper), ) -def test_demo_for_conditional_generation_dataset(ttnn_model, device): +@pytest.mark.parametrize( + ("batch_size", "n_iterations", "max_tokens"), + ((8, 1, 32),), +) +def test_demo_for_conditional_generation_dataset( + ttnn_model, device, use_program_cache, reset_seeds, batch_size, n_iterations, max_tokens +): disable_persistent_kernel_cache() disable_compilation_reports() - return run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, device) + return run_demo_functional_whisper_for_conditional_generation_dataset( + ttnn_model, device, reset_seeds, batch_size, n_iterations, max_tokens + ) diff --git a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py b/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py index 30b5fa712cf..8f1f2cb837d 100644 --- a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py +++ b/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py @@ -198,8 +198,11 @@ def encoder_layer(config, hidden_states, *, parameters): return hidden_states -def encoder(config, inputs_embeds, *, parameters): - hidden_states = inputs_embeds + parameters.embed_positions.weight +def encoder(config, inputs_embeds, *, parameters, device): + weights = ttnn.to_torch(parameters.embed_positions.weight) + inputs_embeds = ttnn.to_torch(inputs_embeds) + hidden_states = torch.add(inputs_embeds, weights) + hidden_states = ttnn.from_torch(hidden_states, device=device, layout=ttnn.TILE_LAYOUT) hidden_states = dropout(hidden_states, p=0, training=False) for encoder_layer_parameter in parameters.layers: @@ -399,8 +402,8 @@ def preprocess_inputs( return input_embeds, decoder_hidden_states, attention_mask -def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): - encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) +def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters, device): + encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder, device=device) last_hidden_state = decoder( config, decoder_hidden_states, @@ -411,6 +414,25 @@ def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attent return last_hidden_state +def whisper_for_conditional_generation( + config, input_embeds, decoder_hidden_states, decoder_attention_mask, *, parameters, device, ttnn_linear_weight +): + output = whisper( + config, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + device=device, + ) + ttnn_output = ttnn.matmul( + output, + ttnn_linear_weight, + dtype=ttnn.bfloat16, + ) + return ttnn_output + + def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): diff --git a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py b/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py index 56b8c0054f6..326777ef154 100644 --- a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py +++ b/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py @@ -195,8 +195,11 @@ def encoder_layer(config, hidden_states, *, parameters): return hidden_states -def encoder(config, inputs_embeds, *, parameters): - hidden_states = inputs_embeds + parameters.embed_positions.weight +def encoder(config, inputs_embeds, *, parameters, device): + weights = ttnn.to_torch(parameters.embed_positions.weight) + inputs_embeds = ttnn.to_torch(inputs_embeds) + hidden_states = torch.add(inputs_embeds, weights) + hidden_states = ttnn.from_torch(hidden_states, device=device, layout=ttnn.TILE_LAYOUT) hidden_states = dropout(hidden_states, p=0, training=False) for encoder_layer_parameter in parameters.layers: @@ -396,8 +399,8 @@ def preprocess_inputs( return input_embeds, decoder_hidden_states, attention_mask -def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): - encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) +def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters, device): + encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder, device=device) last_hidden_state = decoder( config, decoder_hidden_states, @@ -408,6 +411,25 @@ def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attent return last_hidden_state +def whisper_for_conditional_generation( + config, input_embeds, decoder_hidden_states, decoder_attention_mask, *, parameters, device, ttnn_linear_weight +): + output = whisper( + config, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + device=device, + ) + ttnn_output = ttnn.matmul( + output, + ttnn_linear_weight, + dtype=ttnn.bfloat16, + ) + return ttnn_output + + def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): diff --git a/tests/ttnn/integration_tests/whisper/test_demo.py b/tests/ttnn/integration_tests/whisper/test_demo.py new file mode 100644 index 00000000000..5c0fa5363c3 --- /dev/null +++ b/tests/ttnn/integration_tests/whisper/test_demo.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from loguru import logger +from models.experimental.functional_whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper +from models.experimental.functional_whisper.demo.demo import test_demo_for_conditional_generation as demo +from models.experimental.functional_whisper.demo.demo import ( + test_demo_for_conditional_generation_dataset as demo_dataset, +) + + +@pytest.mark.parametrize( + "input_path", + (("models/experimental/functional_whisper/demo/dataset/conditional_generation"),), + ids=["default_input"], +) +@pytest.mark.parametrize( + "batch_size", + (10,), + ids=["batch_10"], +) +@pytest.mark.parametrize( + "ttnn_model", + (ttnn_optimized_functional_whisper,), +) +def test_demo_batch_10(input_path, ttnn_model, device, use_program_cache, reset_seeds, batch_size): + expected_answers = { + 0: " As soon as you ", + 1: " Some festivals have special", + 2: " The original population hasn", + 3: " Although three people ", + 4: " Soon, officers", + 5: " Water is spilling over", + 6: " Naturalist and Philos", + 7: " With only 18 metals", + 8: " Scientists say the explosion", + 9: " According to police the", + } + NUM_RUNS = 5 + measurements, answers = demo(input_path, ttnn_model, device, use_program_cache, reset_seeds, batch_size, NUM_RUNS) + + logger.info(measurements) + logger.info(answers) + + +@pytest.mark.parametrize( + "batch_size, wer", + ( + ( + 7, + 0.86, + ), + ), + ids=["batch_7"], +) +@pytest.mark.parametrize( + "ttnn_model", + (ttnn_optimized_functional_whisper, ttnn_functional_whisper), +) +def test_demo_squadv2_batch_7(ttnn_model, device, reset_seeds, batch_size, wer, use_program_cache): + loop_count = 5 + evals = demo_dataset( + ttnn_model, device, use_program_cache, reset_seeds, batch_size, n_iterations=1, max_tokens=loop_count + ) + assert evals <= wer diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py index 4281672c5fd..bc9345b73ab 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py @@ -167,7 +167,7 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque device=device, ) - output = ttnn_model.encoder(config, ttnn_inputs_embeds, parameters=ttnn_parameters) + output = ttnn_model.encoder(config, ttnn_inputs_embeds, parameters=ttnn_parameters, device=device) output = ttnn.from_device(output) output = ttnn.to_torch(output) @@ -370,6 +370,7 @@ def test_ttnn_whisper(device, ttnn_model): decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=ttnn_parameters, + device=device, ) last_hidden_state = ttnn.from_device(last_hidden_state) last_hidden_state = ttnn.to_torch(last_hidden_state) diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py index 838b8dcbd79..e9b2aaf8ee0 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py @@ -166,7 +166,7 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque input_embeds = ttnn.to_layout(input_embeds, ttnn.TILE_LAYOUT) input_embeds = ttnn.to_device(input_embeds, device) - output = ttnn_model.encoder(config, input_embeds, parameters=ttnn_parameters) + output = ttnn_model.encoder(config, input_embeds, parameters=ttnn_parameters, device=device) output = ttnn.to_torch(output) assert_with_pcc(torch_output, output, 0.968) @@ -361,6 +361,7 @@ def test_ttnn_whisper(tmp_path, device, ttnn_model): decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=ttnn_parameters, + device=device, ) last_hidden_state = ttnn.to_torch(last_hidden_state) ttnn.tracer.visualize(last_hidden_state, file_name=tmp_path / "whisper.svg")