#6343: Add functional_bloom test_demo
kkeerthana0573 committed May 29, 2024
1 parent bc19f9c commit ae75165
Showing 3 changed files with 169 additions and 14 deletions.
37 changes: 29 additions & 8 deletions models/demos/grayskull/functional_bloom/demo/demo_causal_lm.py
@@ -22,13 +22,16 @@
 from models.demos.grayskull.functional_bloom.dataset_utils import get_data


-def generate_next_token(model, config, input_ids, parameters, num_heads, logits_processor, max_length, **kwargs):
+def generate_next_token(
+    model, config, input_ids, parameters, num_heads, logits_processor, device, max_length, **kwargs
+):
     num_tokens = input_ids.shape[-1]
     padded_input_ids, alibi, causal_mask = model.preprocess_inputs(
         input_ids=input_ids,
         num_heads=num_heads,
         max_length=max_length,
         attention_mask=None,
+        device=device,
         **kwargs,
     )
     logits = model.bloom_for_causal_lm(
@@ -49,6 +52,7 @@ def generate(
     logits_processor,
     num_heads,
     num_tokens_to_decode,
+    device,
     max_length=384,
     **kwargs,
 ):
@@ -61,6 +65,7 @@ def generate(
             parameters,
             num_heads,
             logits_processor,
+            device,
             max_length,
             **kwargs,
         )
@@ -92,6 +97,7 @@ def run_bloom_causal_LM_inference(
     input_path,
     model_location_generator,
     device,
+    num_tokens_to_decode=10,
 ):
     torch.manual_seed(1234)
     config = BloomConfig.from_pretrained(model_version)
@@ -130,7 +136,7 @@ def run_bloom_causal_LM_inference(
         tokenizer=tokenizer,
         logits_processor=logits_processor,
         num_heads=num_heads,
-        num_tokens_to_decode=10,
+        num_tokens_to_decode=num_tokens_to_decode,
         device=device,
     )

@@ -151,7 +157,7 @@ def run_bloom_causal_LM_inference(
         "post_processing": profiler.get("post_processing_output_to_string"),
     }

-    return measurements
+    return measurements, generated_text


 def run_bloom_causal_LM_inference_hellaswag(
@@ -160,7 +166,8 @@ def run_bloom_causal_LM_inference_hellaswag(
     batch_size,
     model_location_generator,
     device,
-    loop_count,
+    loop_count=5,
+    num_tokens_to_decode=10,
 ):
     torch.manual_seed(1234)
     config = BloomConfig.from_pretrained(model_version)
@@ -194,7 +201,7 @@ def run_bloom_causal_LM_inference_hellaswag(
         tokenizer=tokenizer,
         logits_processor=logits_processor,
         num_heads=num_heads,
-        num_tokens_to_decode=10,
+        num_tokens_to_decode=num_tokens_to_decode,
         device=device,
     )

@@ -212,6 +219,8 @@ def run_bloom_causal_LM_inference_hellaswag(
     logger.info("Accuracy")
     logger.info(accuracy)

+    return accuracy
+

 @pytest.mark.parametrize(
     "functional_model",
@@ -223,17 +232,20 @@ def test_demo(
     model_location_generator,
     device,
     use_program_cache,
+    batch_size=8,
+    num_tokens_to_decode=10,
 ):
     disable_persistent_kernel_cache()
     disable_compilation_reports()

     return run_bloom_causal_LM_inference(
         model_version="bigscience/bloom-560m",
         functional_model=functional_model,
-        batch_size=8,
+        batch_size=batch_size,
         input_path=input_path,
         model_location_generator=model_location_generator,
         device=device,
+        num_tokens_to_decode=num_tokens_to_decode,
     )

@@ -245,15 +257,24 @@ def test_demo(
     "loop_count",
     ((4),),
 )
-def test_demo_hellaswag(model_location_generator, functional_model, device, use_program_cache, loop_count):
+def test_demo_hellaswag(
+    model_location_generator,
+    functional_model,
+    device,
+    use_program_cache,
+    loop_count,
+    batch_size=8,
+    num_tokens_to_decode=10,
+):
     disable_persistent_kernel_cache()
     disable_compilation_reports()

     return run_bloom_causal_LM_inference_hellaswag(
         model_version="bigscience/bloom-560m",
         functional_model=functional_model,
-        batch_size=8,
+        batch_size=batch_size,
         model_location_generator=model_location_generator,
         device=device,
         loop_count=loop_count,
+        num_tokens_to_decode=num_tokens_to_decode,
     )
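
For orientation (not part of the commit itself), a minimal sketch of how the reworked causal-LM entry point can be driven now that batch_size and num_tokens_to_decode are parameters and the generated text is returned alongside the profiler measurements; model_location_generator and device are assumed to be the usual pytest fixtures supplied by the caller:

    # Illustrative sketch only; the fixture objects are assumed to exist in the caller.
    from models.demos.grayskull.functional_bloom.tt import ttnn_optimized_functional_bloom
    from models.demos.grayskull.functional_bloom.demo.demo_causal_lm import run_bloom_causal_LM_inference

    measurements, generated_text = run_bloom_causal_LM_inference(
        model_version="bigscience/bloom-560m",
        functional_model=ttnn_optimized_functional_bloom,
        batch_size=8,
        input_path="models/demos/grayskull/functional_bloom/demo/input_data_causal_lm.json",
        model_location_generator=model_location_generator,  # pytest fixture
        device=device,  # pytest fixture
        num_tokens_to_decode=10,
    )
    # generated_text holds one decoded continuation per prompt; timings come back in measurements.
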
29 changes: 23 additions & 6 deletions models/demos/grayskull/functional_bloom/demo/demo_qa.py
@@ -100,6 +100,7 @@ def run_bloom_qa_inference(
     input_path,
     model_location_generator,
     device,
+    num_tokens_to_decode,
     reset_seeds,
 ):
     config = BloomConfig.from_pretrained(model_version)
@@ -141,13 +142,14 @@ def run_bloom_qa_inference(
         tokenizer=tokenizer,
         logits_processor=logits_processor,
         num_heads=num_heads,
-        num_tokens_to_decode=10,
+        num_tokens_to_decode=num_tokens_to_decode,
         attention_mask=attention_mask,
         device=device,
     )

     profiler.start("post_processing_output_to_string")
     generated_text = []
+    gen_answers = []
     for i in range(len(generated_ids)):
         generated_text.append(tokenizer.decode(generated_ids[i], skip_special_tokens=True))
     profiler.end("post_processing_output_to_string")
@@ -158,21 +160,23 @@ def run_bloom_qa_inference(
         logger.info("Output Prompt")
         input_prompt_length = len(input_text[i])
         answer = generated_text[i][input_prompt_length:].strip()
+        gen_answers.append(answer)
         logger.info(answer)

     measurements = {
         "preprocessing_parameter": profiler.get("preprocessing_parameter"),
         "post_processing": profiler.get("post_processing_output_to_string"),
     }

-    return measurements
+    return measurements, gen_answers


 def run_bloom_qa_inference_squad(
     model_version,
     functional_model,
     batch_size,
     device,
+    num_tokens_to_decode,
     reset_seeds,
 ):
     config = BloomConfig.from_pretrained(model_version)
@@ -234,7 +238,7 @@ def run_bloom_qa_inference_squad(
         tokenizer=tokenizer,
         logits_processor=logits_processor,
         num_heads=num_heads,
-        num_tokens_to_decode=10,
+        num_tokens_to_decode=num_tokens_to_decode,
         attention_mask=attention_mask,
         device=device,
     )
@@ -266,6 +270,8 @@ def run_bloom_qa_inference_squad(
     logger.info("F1 Score :")
     logger.info(eval_score["f1"])

+    return eval_score
+

 @pytest.mark.parametrize(
     "functional_model",
@@ -281,17 +287,20 @@ def test_demo(
     device,
     use_program_cache,
     reset_seeds,
+    batch_size=8,
+    num_tokens_to_decode=10,
 ):
     disable_persistent_kernel_cache()
     disable_compilation_reports()

     return run_bloom_qa_inference(
         model_version="bigscience/bloom-560m",
         functional_model=functional_model,
-        batch_size=8,
+        batch_size=batch_size,
         input_path=input_path,
         model_location_generator=model_location_generator,
         device=device,
+        num_tokens_to_decode=num_tokens_to_decode,
         reset_seeds=reset_seeds,
     )

@@ -303,14 +312,22 @@ def test_demo(
         ttnn_optimized_functional_bloom,
     ),
 )
-def test_demo_squadv2(functional_model, device, use_program_cache, reset_seeds):
+def test_demo_squadv2(
+    functional_model,
+    device,
+    use_program_cache,
+    reset_seeds,
+    batch_size=8,
+    num_tokens_to_decode=10,
+):
     disable_persistent_kernel_cache()
     disable_compilation_reports()

     return run_bloom_qa_inference_squad(
         model_version="bigscience/bloom-560m",
         functional_model=functional_model,
-        batch_size=8,
+        batch_size=batch_size,
         device=device,
+        num_tokens_to_decode=num_tokens_to_decode,
         reset_seeds=reset_seeds,
     )
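
Likewise, a hedged sketch (again not part of the diff) of the reworked QA entry point, which now returns the prompt-stripped answers that the new integration test below asserts against; the fixture arguments are assumed to come from the test session:

    # Illustrative sketch only; the fixture objects are assumed to exist in the caller.
    from models.demos.grayskull.functional_bloom.tt import ttnn_optimized_functional_bloom
    from models.demos.grayskull.functional_bloom.demo.demo_qa import run_bloom_qa_inference

    measurements, gen_answers = run_bloom_qa_inference(
        model_version="bigscience/bloom-560m",
        functional_model=ttnn_optimized_functional_bloom,
        batch_size=8,
        input_path="models/demos/grayskull/functional_bloom/demo/input_data_qa.json",
        model_location_generator=model_location_generator,  # pytest fixture
        device=device,  # pytest fixture
        num_tokens_to_decode=10,
        reset_seeds=reset_seeds,  # pytest fixture
    )
    # gen_answers[i] is the generated answer with the input prompt stripped off.
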
117 changes: 117 additions & 0 deletions tests/ttnn/integration_tests/bloom/test_demo.py
@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
from loguru import logger
from models.utility_functions import skip_for_wormhole_b0
from models.demos.grayskull.functional_bloom.tt import ttnn_optimized_functional_bloom
from models.demos.grayskull.functional_bloom.demo.demo_causal_lm import test_demo as demo_cg_json
from models.demos.grayskull.functional_bloom.demo.demo_causal_lm import test_demo_hellaswag as demo_cg_hellaswag
from models.demos.grayskull.functional_bloom.demo.demo_qa import test_demo as demo_qa_json
from models.demos.grayskull.functional_bloom.demo.demo_qa import test_demo_squadv2 as demo_qa_squadv2


@pytest.mark.parametrize(
    "input_path",
    (("models/demos/grayskull/functional_bloom/demo/input_data_causal_lm.json"),),
    ids=["default_input"],
)
@pytest.mark.parametrize(
    "ttnn_model, batch_size",
    ((ttnn_optimized_functional_bloom, 7),),
    ids=["batch_7"],
)
@skip_for_wormhole_b0()
def test_demo_batch_7_cg(
    input_path, ttnn_model, model_location_generator, device, use_program_cache, batch_size, reset_seeds
):
    expected_answers = {
        0: "A man is sitting on a roof. He is wearing a hat",
        1: "A boy is running down a track. He is a man who",
        2: "A lady walks to a barbell. She is wearing a black",
        3: "Children bring desert out for their family member. The desert is a",
        4: "A cat is sitting in a cat bed. The cat is sitting",
        5: "We see a bottle of face wash. The bottle is a bottle",
        6: "In home pet groomers demonstrate how to make a pet’s",
    }
    NUM_RUNS = 5
    measurements, answers = demo_cg_json(
        input_path, ttnn_model, model_location_generator, device, use_program_cache, batch_size, NUM_RUNS
    )

    logger.info(measurements)
    logger.info(answers)

    for i in range(batch_size):
        assert expected_answers[i] == answers[i]


@pytest.mark.parametrize(
    "ttnn_model, batch_size, ref_accuracy",
    ((ttnn_optimized_functional_bloom, 7, 0.5),),
    ids=["batch_7"],
)
@skip_for_wormhole_b0()
def test_demo_squadv2_batch_7_cg(
    model_location_generator, ttnn_model, device, use_program_cache, batch_size, ref_accuracy, reset_seeds
):
    loop_count = 2
    NUM_RUNS = 5
    acc = demo_cg_hellaswag(
        model_location_generator, ttnn_model, device, use_program_cache, loop_count, batch_size, NUM_RUNS
    )
    assert acc["accuracy"] >= ref_accuracy


@pytest.mark.parametrize(
    "input_path",
    (("models/demos/grayskull/functional_bloom/demo/input_data_qa.json"),),
    ids=["default_input"],
)
@pytest.mark.parametrize(
    "ttnn_model, batch_size",
    ((ttnn_optimized_functional_bloom, 7),),
    ids=["batch_7"],
)
@skip_for_wormhole_b0()
def test_demo_batch_7_qa(
    input_path, ttnn_model, model_location_generator, device, use_program_cache, reset_seeds, batch_size
):
    expected_answers = {
        0: "Chopin's performances were",
        1: "The first is the composer",
        2: "The early 20th century.",
        3: "Yes. He was a",
        4: "Beyoncé is a family",
        5: "The archbishop of Cant",
        6: "The city of the Holy",
    }
    NUM_RUNS = 5
    measurements, answers = demo_qa_json(
        input_path, ttnn_model, model_location_generator, device, use_program_cache, reset_seeds, batch_size, NUM_RUNS
    )
    logger.info(measurements)
    logger.info(answers)

    for i in range(batch_size):
        assert expected_answers[i] == answers[i]


@pytest.mark.parametrize(
    "ttnn_model, batch_size, f1",
    ((ttnn_optimized_functional_bloom, 6, 3.72),),
    ids=["batch_6"],
)
@skip_for_wormhole_b0()
def test_demo_squadv2_batch_6_qa(ttnn_model, device, use_program_cache, reset_seeds, batch_size, f1):
    loop_count = 5
    eval_score = demo_qa_squadv2(
        ttnn_model,
        device,
        use_program_cache,
        reset_seeds,
        batch_size,
        loop_count,
    )
    assert eval_score["f1"] >= f1