#4504: Add end to end demo for functional t5 model
1 parent d9daa88, commit a536801
Showing 3 changed files with 352 additions and 0 deletions.

@@ -0,0 +1,20 @@
# ttnn_functional_t5 Demo

## How to Run

Use `pytest --disable-warnings --input-path="models/experimental/functional_t5/demo/input_data.json" models/experimental/functional_t5/demo/demo.py::test_functional_t5_demo` to run the functional T5 demo.

If you wish to run the demo with a different input, use `pytest --disable-warnings --input-path="[address_to_your_json_file]" models/experimental/functional_t5/demo/demo.py::test_functional_t5_demo`. The file is expected to contain at least 8 entries (the demo's batch size), each with a `context` and a `question` field, as sketched below.
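
The entries below are only illustrative, but the shape matches what the demo's `load_inputs` helper reads: a JSON list of objects, each carrying a `context` and a `question` string (the file name `my_input_data.json` is just an example). One way to generate such a file:

```python
# Illustrative sketch: write a custom input file in the shape the demo reads.
import json

my_inputs = [
    {
        "context": "Tenstorrent is a computing company that builds hardware for AI.",
        "question": "What does Tenstorrent build?",
    },
    # ... add entries until you have at least 8 (the default batch size)
]

with open("my_input_data.json", "w") as f:
    json.dump(my_inputs, f, indent=4)
```

Pass the resulting path to the demo via `--input-path` as shown above.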

Our second demo runs on the SQuADv2 dataset. Run it with `pytest --disable-warnings models/experimental/functional_t5/demo/demo.py::test_functional_t5_demo_squadv2`.
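
The SQuADv2 run scores the generated answers with the exact-match and F1 metrics from the `evaluate` library. The snippet below is a minimal sketch of that scoring step, mirroring what `demo.py` does; the id, offset, and answer strings are illustrative placeholders.

```python
# Minimal sketch of SQuADv2 scoring; the values are placeholders for illustration.
import evaluate

predictions = [
    {"prediction_text": "Denver Broncos", "id": "example-id-0", "no_answer_probability": 0.0},
]
references = [
    {"answers": {"answer_start": [177], "text": ["Denver Broncos"]}, "id": "example-id-0"},
]

squad_metric = evaluate.load("squad_v2")
scores = squad_metric.compute(predictions=predictions, references=references)
print(scores["exact"], scores["f1"])  # both reported on a 0-100 scale
```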

## Inputs

Inputs are provided from `input_data.json` by default. If you wish to change the inputs, provide a different path to `test_functional_t5_demo` via `--input-path`.

We do not recommend modifying the `input_data.json` file.

## Details

The entry point to the metal T5 model is `t5_for_conditional_generation` in `ttnn_functional_t5.py`. The model picks up certain configs and weights from the huggingface pretrained model. We have used the `t5-small` and `google/flan-t5-small` versions from huggingface as our reference.
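
For orientation, `demo.py` drives that entry point roughly as shown below. This is a condensed sketch, not the demo itself: the token-by-token generation loop, batching, and profiling are omitted, and `device` is assumed to be an already-open ttnn device (the tests obtain one from the pytest `device` fixture).

```python
# Condensed sketch of how demo.py calls the functional T5 entry point.
import torch
import ttnn
from transformers import AutoTokenizer, T5Config, T5ForConditionalGeneration
from ttnn.model_preprocessing import preprocess_model_parameters
from models.experimental.functional_t5.tt import ttnn_functional_t5

model_name = "t5-small"
config = T5Config.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=32)

# Convert the torch weights into ttnn parameters once, then reuse them per forward pass.
# `device` is assumed here: an open ttnn device such as the pytest fixture the tests use.
parameters = preprocess_model_parameters(
    "ttnn_" + model_name,
    initialize_model=lambda: model,
    custom_preprocessor=ttnn_functional_t5.custom_preprocessor,
    device=device,
)

input_ids = tokenizer(
    "question: What does T5 do? context: T5 casts every NLP task as text-to-text.",
    padding="max_length",
    max_length=384,
    truncation=True,
    return_tensors="pt",
).input_ids
decoder_input_ids = model.generation_config.pad_token_id * torch.ones_like(input_ids)

# One forward pass; demo.py repeats this greedily, writing one new token per iteration.
logits, encoder_hidden_states = ttnn_functional_t5.t5_for_conditional_generation(
    config,
    ttnn.to_device(ttnn.from_torch(input_ids), device),
    ttnn.to_device(ttnn.from_torch(decoder_input_ids), device),
    parameters=parameters,
)
```
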
@@ -0,0 +1,282 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import json
import pytest
import torch
import evaluate
from loguru import logger
from datasets import load_dataset
from models.generation_utils import get_logits_processor
import ttnn
import tt_lib

from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config
from models.experimental.functional_t5.tt import ttnn_functional_t5
from models.experimental.functional_t5.tt import ttnn_optimized_functional_t5
from ttnn.model_preprocessing import preprocess_model_parameters

from models.utility_functions import (
    disable_compilation_reports,
    disable_persistent_kernel_cache,
    profiler,
)


def load_inputs(input_path, batch):
    # Read the first `batch` question/context pairs from the JSON input file.
    with open(input_path) as f:
        input_data = json.load(f)
        assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries."

        context = []
        question = []
        for i in range(batch):
            context.append(input_data[i]["context"])
            question.append(input_data[i]["question"])

        return context, question


def run_generate(input_ids, model, config, parameters, device, max_tokens, batch_size, use_optimized_version=False):
    decoded_tt_output = []

    logits_processor = get_logits_processor(input_ids, config)

    # Decoder input starts as all pad tokens; decoder_start_values provides extra
    # pad columns if the generated sequence outgrows the current buffer.
    decoder_start_values = model.generation_config.pad_token_id * torch.ones(1, 128).to(torch.long)
    decoder_input_ids = model.generation_config.pad_token_id * torch.ones(batch_size, input_ids.shape[-1]).to(
        torch.long
    )

    input_ids = ttnn.from_torch(input_ids)
    input_ids = ttnn.to_device(input_ids, device)

    # Greedy decoding: one full encoder-decoder pass per generated token.
    for iteration in range(max_tokens):
        decoder_input_ids = ttnn.from_torch(decoder_input_ids)
        decoder_input_ids = ttnn.to_device(decoder_input_ids, device)

        tt_model = ttnn_optimized_functional_t5 if use_optimized_version else ttnn_functional_t5

        tt_output, encoder_hidden_states = tt_model.t5_for_conditional_generation(
            config,
            input_ids,
            decoder_input_ids,
            parameters=parameters,
        )
        tt_output = ttnn.from_device(tt_output)
        next_token_logits = ttnn.to_torch(tt_output)

        next_tokens_scores = logits_processor(input_ids, next_token_logits)
        next_tokens = torch.argmax(next_tokens_scores, dim=-1)

        decoder_input_ids = ttnn.from_device(decoder_input_ids)
        decoder_input_ids = ttnn.to_torch(decoder_input_ids)

        # Every 32 generated tokens, extend the decoder buffer with more pad columns.
        if (iteration + 1) % 32 == 0:
            decoder_input_ids = torch.cat([decoder_input_ids, decoder_start_values], dim=1)

        # Write the newly generated token into the next decoder position for each batch entry.
        decoder_input_ids[:, iteration + 1] = next_tokens[:, iteration]

    return decoder_input_ids


def run_functional_t5_question_and_answering_inference(
    device, batch_size, sequence_length, max_tokens, model_name, input_path, use_optimized_version
):
    config = T5Config.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=32)

    context, question = load_inputs(input_path, batch_size)

    # T5 expects a text-to-text prompt of the form "question: ... context: ...".
    input_sentence = [f"question: {q} context: {c}" for q, c in zip(question, context)]

    profiler.start(f"preprocessing_input")
    input_ids = tokenizer(
        input_sentence,
        padding="max_length",
        max_length=sequence_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    profiler.end(f"preprocessing_input")

    tt_model_name = "ttnn_" + ("optimized_" if use_optimized_version else "") + model_name

    decoded_tt_output = []

    custom_preprocessor = (
        ttnn_optimized_functional_t5.custom_preprocessor
        if use_optimized_version
        else ttnn_functional_t5.custom_preprocessor
    )

    profiler.start(f"preprocessing_parameter")
    parameters = preprocess_model_parameters(
        tt_model_name,
        initialize_model=lambda: model,
        custom_preprocessor=custom_preprocessor,
        device=device,
    )
    profiler.end(f"preprocessing_parameter")

    profiler.start(f"inference_time")
    tt_output = run_generate(
        input_ids,
        model,
        config,
        parameters,
        device,
        max_tokens,
        batch_size,
        use_optimized_version,
    )
    profiler.end(f"inference_time")

    profiler.start(f"post_processing_output_to_string")
    for batch in range(batch_size):
        output = tokenizer.decode(tt_output[batch], skip_special_tokens=True)
        decoded_tt_output.append(output)
    profiler.end(f"post_processing_output_to_string")

    logger.info(decoded_tt_output)

    measurements = {
        "preprocessing_parameter": profiler.get("preprocessing_parameter"),
        "preprocessing_input": profiler.get("preprocessing_input"),
        "inference_time": profiler.get("inference_time"),
        "post_processing": profiler.get("post_processing_output_to_string"),
    }
    logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s")
    logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s")
    logger.info(f"inference_time: {measurements['inference_time']} s")
    logger.info(f"post_processing : {measurements['post_processing']} s")

    return measurements


def run_functional_t5_question_and_answering_inference_squadv2(
    device, batch_size, sequence_length, max_tokens, model_name, use_optimized_version
):
    config = T5Config.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=32)

    squad_dataset = load_dataset("squad_v2")
    validation_split = squad_dataset["validation"]
    predicted_answers = []
    reference_answer = []
    decoded_tt_output = []

    tt_model_name = "ttnn_" + ("optimized_" if use_optimized_version else "") + model_name
    custom_preprocessor = (
        ttnn_optimized_functional_t5.custom_preprocessor
        if use_optimized_version
        else ttnn_functional_t5.custom_preprocessor
    )

    parameters = preprocess_model_parameters(
        tt_model_name,
        initialize_model=lambda: model,
        custom_preprocessor=custom_preprocessor,
        device=device,
    )

    question = []
    context = []
    answers = []
    id = []

    # Collect the first `batch_size` validation samples that have a ground-truth answer
    # (SQuADv2 also contains unanswerable questions, which are skipped here).
    index = 0
    collected = 0
    while collected < batch_size:
        answer = validation_split["answers"][index]
        if len(answer["text"]) > 0:
            question.append(validation_split["question"][index])
            context.append(validation_split["context"][index])
            answers.append(validation_split["answers"][index])
            id.append(validation_split["id"][index])
            collected += 1
        index += 1

    input_sentence = [f"question: {q} context: {c}" for q, c in zip(question, context)]

    input_ids = tokenizer(
        input_sentence,
        padding="max_length",
        max_length=sequence_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    tt_output = run_generate(
        input_ids,
        model,
        config,
        parameters,
        device,
        max_tokens,
        batch_size,
        use_optimized_version,
    )

    for batch in range(batch_size):
        output = tokenizer.decode(tt_output[batch], skip_special_tokens=True)
        decoded_tt_output.append(output)

    logger.info(decoded_tt_output)

    # Score the generations with the SQuADv2 metric (exact match and F1).
    for batch in range(batch_size):
        predicted_answers.append(
            {
                "prediction_text": decoded_tt_output[batch],
                "id": id[batch],
                "no_answer_probability": 0.0,
            }
        )
        reference_answer.append(
            {
                "answers": {
                    "answer_start": [answers[batch]["answer_start"][0]],
                    "text": [answers[batch]["text"][0]],
                },
                "id": id[batch],
            }
        )
    squad_metric = evaluate.load("squad_v2")
    eval_score = squad_metric.compute(predictions=predicted_answers, references=reference_answer)
    logger.info("Exact Match :")
    logger.info(eval_score["exact"])
    logger.info("F1 Score :")
    logger.info(eval_score["f1"])


@pytest.mark.parametrize(
    ("batch_size", "sequence_length", "max_tokens", "model_name", "use_optimized_version"),
    (
        (8, 384, 5, "t5-small", False),
        (8, 384, 5, "google/flan-t5-small", False),
    ),
)
def test_functional_t5_demo(
    device, batch_size, sequence_length, max_tokens, model_name, input_path, use_optimized_version
):
    disable_persistent_kernel_cache()
    disable_compilation_reports()

    return run_functional_t5_question_and_answering_inference(
        device, batch_size, sequence_length, max_tokens, model_name, input_path, use_optimized_version
    )


@pytest.mark.parametrize(
    ("batch_size", "sequence_length", "max_tokens", "model_name", "use_optimized_version"),
    ((3, 384, 5, "t5-small", False), (3, 384, 5, "google/flan-t5-small", False)),
)
def test_functional_t5_demo_squadv2(device, batch_size, sequence_length, max_tokens, model_name, use_optimized_version):
    disable_persistent_kernel_cache()
    disable_compilation_reports()

    return run_functional_t5_question_and_answering_inference_squadv2(
        device, batch_size, sequence_length, max_tokens, model_name, use_optimized_version
    )