Commit
#8245: functional whisper add support for batched input audio classification

jayasuryamaganuru committed May 16, 2024
1 parent 1455378 commit 638b482
Showing 8 changed files with 312 additions and 110 deletions.
72 changes: 59 additions & 13 deletions models/experimental/functional_whisper/README.md
@@ -1,27 +1,73 @@
# ttnn_functional_whisper Demo
---

## How to Run
# Functional Whisper Model Demos For Audio Classification and Text Generation

Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]` to run the ttnn optimized functional whisper demo for audio classification.
## Introduction

Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper]` to run the ttnn functional whisper demo for audio classification.
Whisper is a pre-trained model for automatic speech recognition (ASR) and speech translation. The models are trained on either English-only data or multilingual data. The English-only models were trained on the task of speech recognition; the multilingual models were trained on both speech recognition and speech translation tasks.

Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]` to run the ttnn optimized functional whisper demo for conditional generation.
The demos showcase the Functional Whisper model for the audio classification and text generation tasks.
The `sanchit-gandhi/whisper-medium-fleurs-lang-id` and `openai/whisper-tiny.en` checkpoints from Hugging Face are used for the respective tasks.

Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper]` to run the ttnn functional whisper demo for conditional generation.
### Details

Another demo is designed to run with `google/fleurs` for audio classification and `hf-internal-testing/librispeech_asr_dummy` for conditional generation.

Use `pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification_dataset` to run audio classification demo with dataset input.

Use `pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation_dataset` to run conditional generation demo with dataset input.
The entry point to the Functional Whisper model is the `whisper` function located in `ttnn_optimized_functional_whisper.py`.
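
For the audio classification task added in this commit, the demo drives the model through `whisper_for_audio_classification`. The following is a minimal sketch of that flow, assuming the helper names and arguments used in `demo.py` (`model`, `device`, `batch_size`, and the batched `input_features` tensor `batched_inputs` are set up beforehand, and `preprocess_model_parameters` is imported as in `demo.py`):

```
from models.experimental.functional_whisper.tt import ttnn_optimized_functional_whisper as ttnn_model

# Convert the Hugging Face WhisperForAudioClassification weights into ttnn parameters.
parameters = preprocess_model_parameters(
    initialize_model=lambda: model,
    convert_to_ttnn=ttnn_model.convert_to_ttnn,
    custom_preprocessor=ttnn_model.custom_preprocessor,
    device=device,
)

# Embed the batched log-mel features on device.
input_embedding = ttnn_model.preprocess_encoder_inputs(
    input_features=batched_inputs, parameters=parameters.encoder, device=device
)

# Run the encoder plus classification head; the demo's postprocessing suggests
# logits of shape (batch_size, 1, num_labels).
logits = ttnn_model.whisper_for_audio_classification(
    config=model.config,
    inputs_embeds=input_embedding,
    parameters=parameters,
    device=device,
    batch_size=batch_size,
)
```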

## Inputs

By default, inputs are taken from the `dataset/audio_classification` and `dataset/conditional_generation` folders. If you wish to change the inputs, provide a different path to the demo.

For the dataset demos, inputs for audio classification are taken from the `google/fleurs` dataset and inputs for conditional generation are taken from the `hf-internal-testing/librispeech_asr_dummy` dataset.

## Details
## Batch size: 8

Batch size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It is recommended to set `batch_size` to 8. A sketch of how a batch is assembled is shown below.
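
As a rough illustration (following the batching loop in `demo.py`; `audio_file_paths` is a hypothetical list of `.wav` files), the demo builds one batched `input_features` tensor by concatenating the per-file features along the batch dimension:

```
import torch
from scipy.io import wavfile
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")

batched_inputs = None
for audio_path in audio_file_paths[:batch_size]:
    samplerate, data = wavfile.read(audio_path)
    inputs = feature_extractor(data, sampling_rate=samplerate, return_tensors="pt")
    features = inputs.input_features  # shape (1, n_mels, n_frames)
    # Stack the per-file features along dim 0 to form the batch.
    batched_inputs = features if batched_inputs is None else torch.cat((batched_inputs, features), dim=0)
# batched_inputs: (batch_size, n_mels, n_frames)
```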

## How to run demo for Audio Classification task

To run the demo for audio classification using the Whisper model, follow these instructions:

- Use the following command to run the whisper for audio classification demo with ttnn optimized functional whisper:
```
pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[8-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]
```
- To run the whisper for audio classification demo with ttnn functional whisper, use the following command:
```
pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[8-models.experimental.functional_whisper.tt.ttnn_functional_whisper]
```
- Another demo is designed to run with the `google/fleurs` dataset for audio classification; to run the dataset demo, use the command:
```
pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification_dataset
```

## How to run demo for Text Generation task

To run the demo for text generation using the Whisper model, follow these instructions:

- Use the following command to run the whisper for text generation demo with ttnn optimized functional whisper:
```
pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]
```
- Use the following command to run the whisper for text generation demo with ttnn functional whisper:
```
pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper]
```
- Another demo is designed to run with the `hf-internal-testing/librispeech_asr_dummy` dataset for text generation; to run the dataset demo, use the command:
```
pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation_dataset
```

## Results

The demos present a comprehensive view of the Whisper model's robustness on the audio classification and text generation tasks.

Audio classification predicts the language of the provided audio sample, and the dataset demo also reports the accuracy of the model.
For example, `batch_size=8` and `n_iterations=3` gives an accuracy of 0.75.
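
A minimal sketch of how the dataset demo derives that accuracy from the batched logits (mirroring the postprocessing in `demo.py`; `id2label` comes from the Hugging Face model config):

```
import torch
from sklearn.metrics import accuracy_score

def labels_from_logits(logits_torch, id2label):
    # logits_torch: (batch_size, 1, num_labels) tensor obtained via ttnn.to_torch(out_logits)
    predicted = []
    for single_logits in logits_torch:
        class_id = torch.argmax(single_logits.squeeze(0)).item()
        predicted.append(id2label[class_id])
    return predicted

# reference_labels collects the dataset's "language" field for each sample;
# predicted_labels accumulates labels_from_logits(...) over all iterations.
accuracy = accuracy_score(reference_labels, predicted_labels)
```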

For text generation, the model predicts transcriptions in the same language as the audio (English).

The entry point to the Whisper model is the `whisper` function in `models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py` for the optimized version (`models/experimental/functional_whisper/tt/ttnn_functional_whisper.py` for the normal version).
---
163 changes: 97 additions & 66 deletions models/experimental/functional_whisper/demo/demo.py
@@ -30,6 +30,7 @@

from transformers import AutoFeatureExtractor, WhisperForAudioClassification
from datasets import load_dataset
from sklearn.metrics import accuracy_score


def load_input_paths(folder_path):
@@ -109,9 +110,9 @@ def run_generate(
return ttnn_transcription


def run_demo_functional_whisper_for_audio_classification_inference(input_path, ttnn_model, device, num_inputs):
torch.manual_seed(1234)

def run_demo_functional_whisper_for_audio_classification_inference(
reset_seeds, input_path, ttnn_model, device, batch_size
):
feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")

@@ -124,10 +125,11 @@ def run_demo_functional_whisper_for_audio_classification_inference(input_path, t
custom_preprocessor=ttnn_model.custom_preprocessor,
device=device,
)
if len(input_data) < num_inputs:
assert False, "num_inputs exceeds number of audio files available in folder"
if len(input_data) < batch_size:
assert False, "batch_size exceeds number of audio files available in folder"

for i in range(num_inputs):
batched_inputs = []
for i in range(batch_size):
input_file_path = input_data[i]
samplerate, data = wavfile.read(input_file_path)

@@ -138,30 +140,33 @@ def run_demo_functional_whisper_for_audio_classification_inference(input_path, t
)

input_features = inputs.input_features
if i == 0:
batched_inputs = input_features
else:
batched_inputs = torch.cat((batched_inputs, input_features), dim=0)

config = model.config
input_embedding = ttnn_model.preprocess_encoder_inputs(
input_features=input_features, parameters=parameters.encoder, device=device
)

encoder_outputs = ttnn_model.encoder(
config=config, inputs_embeds=input_embedding, parameters=parameters.encoder
)

hidden_states = ttnn.matmul(encoder_outputs, parameters.projector.weight)
hidden_states = ttnn.add(hidden_states, parameters.projector.bias)

pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True)
config = model.config
input_embedding = ttnn_model.preprocess_encoder_inputs(
input_features=batched_inputs, parameters=parameters.encoder, device=device
)

logits = ttnn.matmul(pooled_output, parameters.classifier.weight)
logits = ttnn.add(logits, parameters.classifier.bias)
out_logits = ttnn_model.whisper_for_audio_classification(
config=config,
inputs_embeds=input_embedding,
parameters=parameters,
device=device,
batch_size=batch_size,
)

logits_torch = ttnn.to_torch(logits)
predicted_class_ids = torch.argmax(logits_torch).item()
logits_torch = ttnn.to_torch(out_logits)
predicted_list = []
for i in range(batch_size):
single_logits_torch = logits_torch[i].squeeze(0)
predicted_class_ids = torch.argmax(single_logits_torch).item()
predicted_label = model.config.id2label[predicted_class_ids]

logger.info("predicted_label")
logger.info(predicted_label)
logger.info(f"predicted_label: {predicted_label}")
predicted_list.append(predicted_label)
return predicted_list


def run_demo_functional_whisper_for_conditional_generation_inference(input_path, ttnn_model, device, num_inputs):
@@ -235,56 +240,70 @@ def run_demo_functional_whisper_for_conditional_generation_inference(input_path,
logger.info(output_list[i])


def run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, device):
torch.manual_seed(1234)

def run_demo_functional_whisper_for_audio_classification_dataset(
reset_seeds, ttnn_model, device, batch_size=8, n_iterations=1
):
feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")

model.eval()

ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
sample = next(iter(ds))

inputs = feature_extractor(
sample["audio"]["array"],
sampling_rate=sample["audio"]["sampling_rate"],
return_tensors="pt",
)

input_features = inputs.input_features

logger.debug("Input audio language:")
logger.debug(sample["language"])
ds_iter = iter(ds)

reference_labels = []
predicted_labels = []
config = model.config
parameters = preprocess_model_parameters(
initialize_model=lambda: model,
convert_to_ttnn=ttnn_model.convert_to_ttnn,
custom_preprocessor=ttnn_model.custom_preprocessor,
device=device,
)

config = model.config
input_embedding = ttnn_model.preprocess_encoder_inputs(
input_features=input_features, parameters=parameters.encoder, device=device
)

encoder_outputs = ttnn_model.encoder(config=config, inputs_embeds=input_embedding, parameters=parameters.encoder)

hidden_states = ttnn.matmul(encoder_outputs, parameters.projector.weight)
hidden_states = ttnn.add(hidden_states, parameters.projector.bias)

pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True)
for _ in range(n_iterations):
batch_input = []
# prepare the batched audio inputs
for bs in range(batch_size):
sample = next(ds_iter)
inputs = feature_extractor(
sample["audio"]["array"],
sampling_rate=sample["audio"]["sampling_rate"],
return_tensors="pt",
)
input_features = inputs.input_features
if bs == 0:
batch_input = input_features
else:
batch_input = torch.cat((batch_input, input_features), dim=0)
reference_labels.append(sample["language"])

# preprocess the inputs
input_embedding = ttnn_model.preprocess_encoder_inputs(
input_features=batch_input, parameters=parameters.encoder, device=device
)

logits = ttnn.matmul(pooled_output, parameters.classifier.weight)
logits = ttnn.add(logits, parameters.classifier.bias)
# run the model
out_logits = ttnn_model.whisper_for_audio_classification(
config=config,
inputs_embeds=input_embedding,
parameters=parameters,
device=device,
batch_size=batch_size,
)

logits_torch = ttnn.to_torch(logits)
predicted_class_ids = torch.argmax(logits_torch).item()
predicted_label = model.config.id2label[predicted_class_ids]
# postprocessing the outputs
logits_torch = ttnn.to_torch(out_logits)
for i in range(batch_size):
single_logits_torch = logits_torch[i].squeeze(0)
predicted_class_ids = torch.argmax(single_logits_torch).item()
predicted_label = model.config.id2label[predicted_class_ids]
predicted_labels.append(predicted_label)

logger.info("predicted_label")
logger.info(predicted_label)
accuracy = accuracy_score(reference_labels, predicted_labels)
logger.info(f"reference labels: {reference_labels}")
logger.info(f"predicted labels: {predicted_labels}")
logger.info(f"Accuracy: {accuracy}")
return accuracy


def run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, device):
@@ -353,13 +372,15 @@ def run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, d
(ttnn_optimized_functional_whisper, ttnn_functional_whisper),
)
@pytest.mark.parametrize(
"num_inputs",
((1),),
"batch_size",
((8),),
)
def test_demo_for_audio_classification(input_path, ttnn_model, device, num_inputs):
def test_demo_for_audio_classification(reset_seeds, input_path, ttnn_model, device, batch_size):
disable_persistent_kernel_cache()
disable_compilation_reports()
return run_demo_functional_whisper_for_audio_classification_inference(input_path, ttnn_model, device, num_inputs)
return run_demo_functional_whisper_for_audio_classification_inference(
reset_seeds, input_path, ttnn_model, device, batch_size
)


@pytest.mark.parametrize(
@@ -380,10 +401,20 @@ def test_demo_for_conditional_generation(input_path, ttnn_model, device, num_inp
"ttnn_model",
(ttnn_optimized_functional_whisper, ttnn_functional_whisper),
)
def test_demo_for_audio_classification_dataset(ttnn_model, device):
@pytest.mark.parametrize(
"batch_size",
((8),),
)
@pytest.mark.parametrize(
"n_iterations",
((5),),
)
def test_demo_for_audio_classification_dataset(reset_seeds, ttnn_model, device, batch_size, n_iterations):
disable_persistent_kernel_cache()
disable_compilation_reports()
return run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, device)
return run_demo_functional_whisper_for_audio_classification_dataset(
reset_seeds, ttnn_model, device, batch_size=batch_size, n_iterations=n_iterations
)


@pytest.mark.parametrize(
@@ -198,8 +198,14 @@ def encoder_layer(config, hidden_states, *, parameters):
return hidden_states


def encoder(config, inputs_embeds, *, parameters):
hidden_states = inputs_embeds + parameters.embed_positions.weight
def encoder(config, inputs_embeds, *, parameters, device, batch_size):
# issue #7872
# broadcast add is not happening for batched inputs
# hidden_states = inputs_embeds + parameters.embed_positions.weight
weights = ttnn.to_torch(parameters.embed_positions.weight)
embeds = ttnn.to_torch(inputs_embeds)
hidden_states = torch.add(weights, embeds)
hidden_states = ttnn.from_torch(hidden_states, device=device, layout=ttnn.TILE_LAYOUT)
hidden_states = dropout(hidden_states, p=0, training=False)

for encoder_layer_parameter in parameters.layers:
@@ -399,8 +405,19 @@ def preprocess_inputs(
return input_embeds, decoder_hidden_states, attention_mask


def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters):
encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder)
def whisper(
config,
encoder_hidden_states,
decoder_hidden_states,
decoder_attention_mask,
*,
parameters,
device=None,
batch_size=1,
):
encoder_hidden_states = encoder(
config, encoder_hidden_states, parameters=parameters.encoder, device=device, batch_size=batch_size
)
last_hidden_state = decoder(
config,
decoder_hidden_states,
@@ -411,6 +428,26 @@ def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attent
return last_hidden_state


def whisper_for_audio_classification(config, inputs_embeds, *, parameters, device, batch_size):
encoder_outputs = encoder(
config=config,
inputs_embeds=inputs_embeds,
parameters=parameters.encoder,
device=device,
batch_size=batch_size,
)

hidden_states = ttnn.matmul(encoder_outputs, parameters.projector.weight)
hidden_states = ttnn.add(hidden_states, parameters.projector.bias)

pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True)

logits = ttnn.matmul(pooled_output, parameters.classifier.weight)
logits = ttnn.add(logits, parameters.classifier.bias)

return logits
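
For reference, a hedged plain-PyTorch sketch of the same classification-head computation as above (shapes assumed from the demo's postprocessing; the weight/bias arguments stand in for the ttnn parameters):

```
import torch

def audio_classification_head(encoder_outputs, projector_w, projector_b, classifier_w, classifier_b):
    # encoder_outputs: (batch_size, seq_len, d_model)
    hidden_states = encoder_outputs @ projector_w + projector_b   # project encoder states
    pooled_output = hidden_states.mean(dim=-2, keepdim=True)      # mean-pool over the sequence
    logits = pooled_output @ classifier_w + classifier_b          # (batch_size, 1, num_labels)
    return logits
```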


def custom_preprocessor(torch_model, name):
parameters = {}
if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention):