#4505: Add end to end demo for functional bert model

1 parent 625f3b6, commit 85c424f. Showing 3 changed files with 383 additions and 0 deletions.
@@ -0,0 +1,21 @@

## functional_bert Demo

## How to Run

Use `pytest --disable-warnings --input-path="models/experimental/functional_bert/demo/input_data.json" models/experimental/functional_bert/demo/demo.py::test_demo[models.experimental.functional_bert.tt.ttnn_functional_bert-phiyodr/bert-large-finetuned-squad2]` to run the demo.

To run the demo for ttnn_optimized_functional_bert, use `pytest --disable-warnings --input-path="models/experimental/functional_bert/demo/input_data.json" models/experimental/functional_bert/demo/demo.py::test_demo[models.experimental.functional_bert.tt.ttnn_optimized_functional_bert-phiyodr/bert-large-finetuned-squad2]`.

To run the demo with different inputs, use `pytest --disable-warnings --input-path="<address_to_your_json_file.json>" models/experimental/functional_bert/demo/demo.py::test_demo[models.experimental.functional_bert.tt.ttnn_functional_bert-phiyodr/bert-large-finetuned-squad2]`. The input file must contain at least 8 entries (the batch size); only the first 8 are used.

Our second demo runs on the SQuAD v2 dataset; run it with `pytest --disable-warnings models/experimental/functional_bert/demo/demo.py::test_demo_squadv2[3-models.experimental.functional_bert.tt.ttnn_optimized_functional_bert-phiyodr/bert-large-finetuned-squad2]`.

To run it for a different number of iterations, use `pytest --disable-warnings models/experimental/functional_bert/demo/demo.py::test_demo_squadv2[<n_iterations>-models.experimental.functional_bert.tt.ttnn_optimized_functional_bert-phiyodr/bert-large-finetuned-squad2]`, where `<n_iterations>` is the number of batches to evaluate.
## Inputs

Inputs are provided from `input_data.json` by default. If you wish to change the inputs, provide a different path to `test_demo`.

We do not recommend modifying the `input_data.json` file.
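For reference, each entry read by `load_inputs` in `demo.py` is a JSON object with `context` and `question` keys. A minimal illustration of the expected format (hypothetical contents; the real file must contain at least 8 entries):

```json
[
    {
        "context": "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.",
        "question": "Where is the Eiffel Tower located?"
    },
    {
        "context": "Johannes Gutenberg introduced the movable-type printing press to Europe around 1440.",
        "question": "Who introduced the movable-type printing press to Europe?"
    }
]
```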
## Details

The entry point to the functional_bert model is `bert_for_question_answering` in `models/experimental/functional_bert/tt/ttnn_functional_bert.py` (`models/experimental/functional_bert/tt/ttnn_optimized_functional_bert.py` for the optimized version). The model picks up its config and weights from a Hugging Face pretrained model; we use the `phiyodr/bert-large-finetuned-squad2` checkpoint as our reference.
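To make the flow concrete, here is a minimal sketch (not part of this commit) of how the demo below wires these pieces together for a single question; the `device` handle, `tokenizer`, `config`, and `parameters` are assumed to be set up exactly as in `demo.py`:

```python
import torch
import ttnn
from models.experimental.functional_bert.tt import ttnn_functional_bert


def answer_logits(device, tokenizer, config, parameters, question, context, sequence_size=384):
    # Tokenize one (question, context) pair, padded to the fixed sequence length.
    encoding = tokenizer.batch_encode_plus(
        [(question, context)],
        max_length=sequence_size,
        padding="max_length",
        truncation=True,
        return_token_type_ids=True,
        return_tensors="pt",
    )
    # Move the tokenized inputs onto the device (the non-optimized variant
    # takes None in place of the extra tensor used by the optimized one).
    ttnn_inputs = ttnn_functional_bert.preprocess_inputs(
        encoding["input_ids"],
        encoding["token_type_ids"],
        None,
        device=device,
    )
    output = ttnn_functional_bert.bert_for_question_answering(config, *ttnn_inputs, parameters=parameters)
    output = ttnn.to_torch(ttnn.from_device(output)).reshape(1, 1, sequence_size, -1).to(torch.float32)
    # The last dimension holds the per-token start and end logits.
    return output[..., :, 0].squeeze(1), output[..., :, 1].squeeze(1)
```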
`models/experimental/functional_bert/demo/demo.py`:

@@ -0,0 +1,312 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import json

import pytest
import torch
from loguru import logger

import transformers
import ttnn
import tt_lib
from models.utility_functions import (
    disable_compilation_reports,
    disable_persistent_kernel_cache,
    profiler,
)
from models.experimental.functional_bert.tt import ttnn_functional_bert
from models.experimental.functional_bert.tt import ttnn_optimized_functional_bert

from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch
from ttnn.model_preprocessing import preprocess_model_parameters

from transformers import BertForQuestionAnswering, BertTokenizer, pipeline

import evaluate

def load_inputs(input_path, batch):
    with open(input_path) as f:
        input_data = json.load(f)
    assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries."

    context = []
    question = []
    for i in range(batch):
        context.append(input_data[i]["context"])
        question.append(input_data[i]["question"])

    return context, question

def run_bert_question_and_answering_inference(
    device,
    use_program_cache,
    model_name,
    batch_size,
    sequence_size,
    functional_bert,
    model_location_generator,
    input_path,
):
    disable_persistent_kernel_cache()

    model = str(model_location_generator(model_name, model_subdir="Bert"))
    hugging_face_reference_model = BertForQuestionAnswering.from_pretrained(model, torchscript=False)
    hugging_face_reference_model.eval()

    # set up tokenizer
    tokenizer_name = str(model_location_generator(model_name, model_subdir="Bert"))
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    config = hugging_face_reference_model.config
    nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer)

    if functional_bert == ttnn_functional_bert:
        tt_model_name = f"ttnn_{model_name}"
    elif functional_bert == ttnn_optimized_functional_bert:
        tt_model_name = f"ttnn_{model_name}_optimized"
    else:
        raise ValueError(f"Unknown functional_bert: {functional_bert}")

    profiler.start("preprocessing_parameter")
    parameters = preprocess_model_parameters(
        tt_model_name,
        initialize_model=lambda: transformers.BertForQuestionAnswering.from_pretrained(
            model_name, torchscript=False
        ).eval(),
        custom_preprocessor=functional_bert.custom_preprocessor,
        device=device,
    )
    profiler.end("preprocessing_parameter")

    context, question = load_inputs(input_path, batch_size)

    # Reuse the Hugging Face QA pipeline's pre/post-processing so answers can be
    # decoded from raw logits the same way as for the reference model.
    preprocess_params, _, postprocess_params = nlp._sanitize_parameters()
    preprocess_params["max_seq_len"] = sequence_size
    inputs = nlp._args_parser({"context": context, "question": question})
    preprocessed_inputs = []
    for i in range(batch_size):
        model_input = next(nlp.preprocess(inputs[0][i], **preprocess_params))
        single_input = {
            "example": model_input["example"],
            "inputs": model_input,
        }
        preprocessed_inputs.append(single_input)

    bert_input = tokenizer.batch_encode_plus(
        zip(question, context),
        max_length=sequence_size,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_token_type_ids=True,
        return_tensors="pt",
    )

    profiler.start("preprocessing_input")
    ttnn_bert_inputs = functional_bert.preprocess_inputs(
        bert_input["input_ids"],
        bert_input["token_type_ids"],
        torch.zeros(1, sequence_size) if functional_bert == ttnn_optimized_functional_bert else None,
        device=device,
    )
    profiler.end("preprocessing_input")

    profiler.start("inference_time")
    tt_output = functional_bert.bert_for_question_answering(
        config,
        *ttnn_bert_inputs,
        parameters=parameters,
    )
    profiler.end("inference_time")

    tt_output = ttnn.to_torch(ttnn.from_device(tt_output)).reshape(batch_size, 1, sequence_size, -1).to(torch.float32)

    # The last dimension holds the per-token start and end logits.
    tt_start_logits = tt_output[..., :, 0].squeeze(1)
    tt_end_logits = tt_output[..., :, 1].squeeze(1)

    model_answers = {}
    profiler.start("post_processing_output_to_string")
    for i in range(batch_size):
        tt_res = {
            "start": tt_start_logits[i],
            "end": tt_end_logits[i],
            "example": preprocessed_inputs[i]["example"],
            **preprocessed_inputs[i]["inputs"],
        }

        tt_answer = nlp.postprocess([tt_res], **postprocess_params)

        logger.info(f"answer: {tt_answer['answer']}\n")
        model_answers[i] = tt_answer["answer"]

    profiler.end("post_processing_output_to_string")

    measurements = {
        "preprocessing_parameter": profiler.get("preprocessing_parameter"),
        "preprocessing_input": profiler.get("preprocessing_input"),
        "inference_time": profiler.get("inference_time"),
        "post_processing": profiler.get("post_processing_output_to_string"),
    }
    logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s")
    logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s")
    logger.info(f"inference_time: {measurements['inference_time']} s")
    logger.info(f"post_processing: {measurements['post_processing']} s")

    return measurements

def run_bert_question_and_answering_inference_squad_v2(
    device,
    use_program_cache,
    model_name,
    batch_size,
    sequence_size,
    functional_bert,
    model_location_generator,
    n_iterations,
):
    disable_persistent_kernel_cache()

    model = str(model_location_generator(model_name, model_subdir="Bert"))
    hugging_face_reference_model = BertForQuestionAnswering.from_pretrained(model, torchscript=False)
    hugging_face_reference_model.eval()

    # set up tokenizer
    tokenizer_name = str(model_location_generator(model_name, model_subdir="Bert"))
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    config = hugging_face_reference_model.config

    if functional_bert == ttnn_functional_bert:
        tt_model_name = f"ttnn_{model_name}"
    elif functional_bert == ttnn_optimized_functional_bert:
        tt_model_name = f"ttnn_{model_name}_optimized"
    else:
        raise ValueError(f"Unknown functional_bert: {functional_bert}")

    parameters = preprocess_model_parameters(
        tt_model_name,
        initialize_model=lambda: transformers.BertForQuestionAnswering.from_pretrained(
            model_name, torchscript=False
        ).eval(),
        custom_preprocessor=functional_bert.custom_preprocessor,
        device=device,
    )

    nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer)

    attention_mask = True
    token_type_ids = True
    inputs_squadv2 = squadv2_1K_samples_input(tokenizer, sequence_size, attention_mask, token_type_ids, batch_size)
    squad_metric = evaluate.load("squad_v2")

    with torch.no_grad():
        pred_labels = []
        cpu_pred_labels = []
        true_labels = []
        i = 0
        # Run both the device model and the CPU reference on the first
        # n_iterations batches, then score both against the references.
        for batch in inputs_squadv2:
            if i < n_iterations:
                batch_data = batch[0]
                curr_batch_size = batch_data["input_ids"].shape[0]
                ttnn_bert_inputs = functional_bert.preprocess_inputs(
                    batch_data["input_ids"],
                    batch_data["token_type_ids"],
                    torch.zeros(1, sequence_size) if functional_bert == ttnn_optimized_functional_bert else None,
                    device=device,
                )

                tt_output = functional_bert.bert_for_question_answering(
                    config,
                    *ttnn_bert_inputs,
                    parameters=parameters,
                )
                tt_output = (
                    ttnn.to_torch(ttnn.from_device(tt_output))
                    .reshape(batch_size, 1, sequence_size, -1)
                    .to(torch.float32)
                )
                cpu_output = hugging_face_reference_model(**batch_data)
                references = batch[1]
                question = batch[2]
                context = batch[3]

                cpu_predictions, tt_predictions = squadv2_answer_decode_batch(
                    hugging_face_reference_model,
                    tokenizer,
                    nlp,
                    references,
                    cpu_output,
                    tt_output,
                    curr_batch_size,
                    question,
                    context,
                )
                pred_labels.extend(tt_predictions)
                cpu_pred_labels.extend(cpu_predictions)
                true_labels.extend(references)

                del tt_output
            i += 1
        eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels)
        cpu_eval_score = squad_metric.compute(predictions=cpu_pred_labels, references=true_labels)
        logger.info(f"\tTT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}")
        logger.info(f"\tCPU_Eval: exact: {cpu_eval_score['exact']} -- F1: {cpu_eval_score['f1']}")

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("functional_bert", [ttnn_functional_bert, ttnn_optimized_functional_bert])
def test_demo(
    input_path,
    model_name,
    functional_bert,
    model_location_generator,
    device,
    use_program_cache,
):
    disable_persistent_kernel_cache()
    disable_compilation_reports()

    tt_lib.profiler.set_profiler_location("tt_metal/tools/profiler/logs/functional_bert")
    return run_bert_question_and_answering_inference(
        device=device,
        use_program_cache=use_program_cache,
        model_name=model_name,
        batch_size=8,
        sequence_size=384,
        functional_bert=functional_bert,
        model_location_generator=model_location_generator,
        input_path=input_path,
    )


@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("functional_bert", [ttnn_functional_bert, ttnn_optimized_functional_bert])
@pytest.mark.parametrize(
    "n_iterations",
    (3,),
)
def test_demo_squadv2(
    model_name,
    functional_bert,
    n_iterations,
    model_location_generator,
    device,
    use_program_cache,
):
    disable_persistent_kernel_cache()
    disable_compilation_reports()

    return run_bert_question_and_answering_inference_squad_v2(
        device=device,
        use_program_cache=use_program_cache,
        model_name=model_name,
        batch_size=8,
        sequence_size=384,
        functional_bert=functional_bert,
        model_location_generator=model_location_generator,
        n_iterations=n_iterations,
    )