Skip to content

Commit

Permalink
#4003: Adding whisper functional model
Browse files Browse the repository at this point in the history
  • Loading branch information
eyonland committed Dec 23, 2023
1 parent 047d9f6 commit 30bb5a5
Show file tree
Hide file tree
Showing 17 changed files with 2,597 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from transformers import AutoFeatureExtractor, WhisperModel
from datasets import load_dataset
import torch


# Generates the baseline expected results for tests within ttnn
if __name__ == "__main__":
model_name = "openai/whisper-base"
model = WhisperModel.from_pretrained(model_name).to(torch.bfloat16).eval()
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
input_features = inputs.input_features.type(torch.bfloat16)
decoder_input_ids = torch.ones(1, 32).type(torch.int32) * model.config.decoder_start_token_id
parameters = model.state_dict()
print(parameters.keys())
last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
print(last_hidden_state.shape)
last_three = last_hidden_state[0, -1, -3:]
print(last_three)
Loading

0 comments on commit 30bb5a5

Please sign in to comment.