
Commit

#3918: revert perf file changes due to hang | fix demo perf profile
farbabi committed Nov 24, 2023
1 parent 35370cc commit a71ef03
Showing 2 changed files with 67 additions and 63 deletions.
123 changes: 67 additions & 56 deletions models/demos/falcon7b/demo/demo.py
@@ -8,12 +8,12 @@
 import tt_lib
 import torch
 from loguru import logger
+import time
 
 from transformers import AutoTokenizer
 
 from models.demos.falcon7b.tt.falcon_causallm import TtFalconCausalLM
 
-from models.demos.falcon7b.reference.hf_modeling_falcon import FalconConfig
+from models.demos.falcon7b.reference.hf_modeling_falcon import FalconConfig, FalconForCausalLM
 from models.demos.falcon7b.tt.model_config import get_model_config, get_tt_cache_path, model_config_entries
 from models.utility_functions import (
     disable_compilation_reports,
@@ -24,8 +24,6 @@
     tt2torch_tensor,
 )
 
-import time
-
 END_OF_TEXT = 11
 SPACE = 204
 
@@ -108,9 +106,16 @@ def run_falcon_demo_kv(
     profiler.end(f"loading_inputs")
 
     # State dict is needed for embeddings
-    logger.info("Loading TT model weights")
-    profiler.start(f"loading_weights")
-    state_dict = {"transformer.word_embeddings.weight": torch.load(tt_cache_path / "embedding.pt")}
+    logger.info("Loading TT weights and model...")
+    profiler.start(f"loading_weights_model")
+    if tt_cache_path:
+        state_dict = {"transformer.word_embeddings.weight": torch.load(tt_cache_path / "embedding.pt")}
+    else:
+        model_name = model_location_generator(model_version, model_subdir="Falcon")
+        hugging_face_reference_model = FalconForCausalLM.from_pretrained(model_name)
+        hugging_face_reference_model.eval()
+        state_dict = hugging_face_reference_model.state_dict()
+
     tt_lib.device.Synchronize()
 
     base_url = ""
@@ -125,9 +130,10 @@
         tt_cache_path,
     )
 
+    logger.info("Loaded TT weights and model")
+    profiler.end(f"loading_weights_model")
+
     tt_lib.device.Synchronize()
-    logger.info("Loaded TT model weights")
-    profiler.end(f"loading_weights")
 
     logger.info("Tokenizing inputs")
     profiler.start(f"tokenizing_inputs")
@@ -145,13 +151,12 @@
 
     ### First prefill run with compile ###
     logger.info("Running 1st run prefill stage with compile...")
-    profiler.start(f"first_run_prefill_stage_compile", force_enable=True)
     post_processor = partial(post_process)
     use_cache = True
     output_ids = torch.zeros(num_users, 1, dtype=torch.int64)
+    time_prefill_compile = 0
     for user_id in range(num_users):
-        prefill_wc_start = time.time()
-
+        time_prefill_compile_start = time.time()
         (
             tt_prefill_embeddings,
             tt_prefill_attention_mask,
@@ -169,6 +174,8 @@
             layer_past_len=0,
             use_cache=use_cache,
         )
+        time_prefill_compile_end = time.time()
+        time_prefill_compile += time_prefill_compile_end - time_prefill_compile_start
 
         tt_prefill_embeddings.deallocate()
         if tt_prefill_attention_mask is not None:
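The hunks above replace profiler start/end scopes with explicit wall-clock accumulation via time.time(), measured around only the model call. A minimal standalone sketch of that pattern (run_step is a hypothetical stand-in for the per-user model call):

import time

def timed_total(run_step, n_iters):
    # Accumulate wall-clock seconds across iterations; on TT hardware the
    # first pass typically includes one-time kernel compilation.
    total = 0.0
    for i in range(n_iters):
        start = time.time()
        run_step(i)
        total += time.time() - start
    return total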
@@ -180,26 +187,24 @@
         user_output_ids = post_processor(logits=logits, index=num_input_tokens - 1)
         output_ids[user_id] = user_output_ids
 
-        prefill_wc_end = time.time()
 
     generated_ids = torch.concat((prefill_ids[..., :num_input_tokens], output_ids), dim=1)
 
+    tt_lib.device.Synchronize()
     logger.info("Finished 1st run prefill stage with compile")
-    profiler.end(f"first_run_prefill_stage_compile", force_enable=True)
 
     ### First run decode stage with compile ###
     logger.info("Running 1st run decode stage with compile...")
-    profiler.start(f"first_run_decode_stage_compile", force_enable=True)
     decode_ids = torch.zeros(batch_size, 1, dtype=torch.int64)
 
     for user_id, output_id in enumerate(output_ids):
         decode_ids[user_id] = output_id
 
     kv_cache_len = num_input_tokens  # This will increment by one after each decode
     prompt_is_done = [False for _ in range(num_users)]
-    for output_token_index in range(max_seq_len - num_input_tokens):
-        decode_wc_start = time.time()
-
+    time_decode_compile = 0
+    for output_token_index in range(max_seq_len - num_input_tokens):
+        time_decode_compile_start = time.time()
         (
             tt_decode_embeddings,
             tt_decode_attention_mask,
@@ -214,6 +219,9 @@
             layer_past_len=kv_cache_len,
             use_cache=use_cache,
         )
+        time_decode_compile_end = time.time()
+        time_decode_compile += time_decode_compile_end - time_decode_compile_start
+
         tt_decode_embeddings.deallocate()
         if tt_decode_attention_mask is not None:
             tt_decode_attention_mask.deallocate()
@@ -235,11 +243,8 @@
         generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1)
         kv_cache_len += 1
 
-        decode_wc_end = time.time()
-
-    tt_lib.device.Synchronize()
     logger.info("Finished 1st run decode stage with compile")
-    profiler.end(f"first_run_decode_stage_compile", force_enable=True)
+    tt_lib.device.Synchronize()
 
     del user_output_ids
     del output_ids
@@ -256,15 +261,13 @@
     profiler.enable()
     enable_persistent_kernel_cache()
 
-    logger.info("Running inference prefill stage...")
-    profiler.start(f"second_run_prefill_stage", force_enable=True)
-
     post_processor = partial(post_process)
     use_cache = True
     output_ids = torch.zeros(num_users, 1, dtype=torch.int64)
+    logger.info("Running inference prefill stage...")
+    time_prefill_inference = 0
     for user_id in range(num_users):
-        prefill_start = time.time()
-
+        time_prefill_inference_start = time.time()
         (
             tt_prefill_embeddings,
             tt_prefill_attention_mask,
@@ -282,6 +285,8 @@
             layer_past_len=0,
             use_cache=use_cache,
         )
+        time_prefill_inference_end = time.time()
+        time_prefill_inference += time_prefill_inference_end - time_prefill_inference_start
 
         tt_prefill_embeddings.deallocate()
         if tt_prefill_attention_mask is not None:
@@ -293,26 +298,25 @@
         user_output_ids = post_processor(logits=logits, index=num_input_tokens - 1)
         output_ids[user_id] = user_output_ids
 
-        prefill_end = time.time()
+    logger.info("Finished inference prefill stage")
 
     generated_ids = torch.concat((prefill_ids[..., :num_input_tokens], output_ids), dim=1)
 
-    logger.info("Finished inference prefill stage")
-    profiler.end(f"second_run_prefill_stage", force_enable=True)
     profiler.disable()
 
     ### Inference run decode ###
     logger.info("Running inference decode stage...")
-    profiler.start(f"second_run_decode_stage", force_enable=True)
-
     decode_ids = torch.zeros(batch_size, 1, dtype=torch.int64)
     for user_id, output_id in enumerate(output_ids):
         decode_ids[user_id] = output_id
 
     kv_cache_len = num_input_tokens  # This will increment by one after each decode
     prompt_is_done = [False for _ in range(num_users)]
-    for output_token_index in range(max_seq_len - num_input_tokens):
-        decode_start = time.time()
-
+    time_decode_inference = 0
+    for output_token_index in range(max_seq_len - num_input_tokens):
+        time_decode_inference_start = time.time()
         (
             tt_decode_embeddings,
             tt_decode_attention_mask,
@@ -327,6 +331,9 @@
             layer_past_len=kv_cache_len,
             use_cache=use_cache,
         )
+        time_decode_inference_end = time.time()
+        time_decode_inference += time_decode_inference_end - time_decode_inference_start
+
         tt_decode_embeddings.deallocate()
         if tt_decode_attention_mask is not None:
             tt_decode_attention_mask.deallocate()
@@ -348,10 +355,8 @@
         generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1)
         kv_cache_len += 1
 
-        decode_end = time.time()
-
     logger.info("Finished inference decode stage")
-    profiler.end(f"second_run_decode_stage", force_enable=True)
+    logger.info(f"Total number of tokens generated in decode: {batch_size*(kv_cache_len)}")
 
     print_output_prompts(generated_ids, tokenizer)
 
@@ -361,29 +366,35 @@
 
     measurements = {
         "preprocessing": profiler.get("tokenizing_inputs"),
+        "loading_weights_model": profiler.get("loading_weights_model"),
         "initializing_KV_cache": profiler.get("initializing_KV_cache"),
-        "compile_prefill": profiler.get("first_run_prefill_stage_compile") - profiler.get("second_run_prefill_stage"),
-        "compile_decode": profiler.get("first_run_decode_stage_compile") - profiler.get("second_run_decode_stage"),
-        "compile_total": profiler.get("first_run_prefill_stage_compile")
-        - profiler.get("second_run_prefill_stage")
-        + profiler.get("first_run_decode_stage_compile")
-        - profiler.get("second_run_decode_stage"),
-        "inference_prefill": profiler.get("second_run_prefill_stage"),
-        "inference_decode": profiler.get("second_run_decode_stage"),
-        "inference_total": profiler.get("second_run_prefill_stage") + profiler.get("second_run_decode_stage"),
-        "inference_throughput": (batch_size * output_token_index)
-        / (profiler.get("second_run_prefill_stage") + profiler.get("second_run_decode_stage")),
+        "compile_prefill": time_prefill_compile - time_prefill_inference,
+        "compile_decode": time_decode_compile - time_decode_inference,
+        "compile_total": time_prefill_compile - time_prefill_inference + time_decode_compile - time_decode_inference,
+        "inference_prefill": time_prefill_inference,
+        "inference_decode": time_decode_inference,
+        "inference_total": time_prefill_inference + time_decode_inference,
+        "inference_throughput_prefill": num_users / time_prefill_inference,
+        "inference_throughput_decode": batch_size / time_decode_inference,
     }
 
-    logger.info(f"pre processing duration: {measurements['preprocessing']} s")
-    logger.info(f"initializing KV cache duration: {measurements['initializing_KV_cache']} s")
-    logger.info(f"prefill compile time: {measurements['compile_prefill']} s")
-    logger.info(f"decode compile time: {measurements['compile_decode']} s")
-    logger.info(f"total compile time: {measurements['compile_total']} s")
-    logger.info(f"prefill inference time: {measurements['inference_prefill']} s")
-    logger.info(f"decode inference time: {measurements['inference_decode']} s")
-    logger.info(f"total inference time: {measurements['inference_total']} s")
-    logger.info(f"inference throughput: {measurements['inference_throughput']} inp/s")
+    logger.info(f"pre processing duration: {round(measurements['preprocessing'], 5)} s")
+    logger.info(f"loading weights and model duration: {round(measurements['loading_weights_model'], 5)} s")
+    logger.info(f"initializing KV cache duration: {round(measurements['initializing_KV_cache'], 5)} s")
+    logger.info(f"prefill compile time: {round(measurements['compile_prefill'], 5)} s")
+    logger.info(f"decode compile time: {round(measurements['compile_decode'], 5)} s")
+    logger.info(f"total compile time: {round(measurements['compile_total'], 5)} s")
+    logger.info(f"prefill inference time: {round(measurements['inference_prefill'], 5)} s")
+    logger.info(f"decode inference time: {round(measurements['inference_decode'], 5)} s")
+    logger.info(f"total inference time: {round(measurements['inference_total'], 5)} s")
+    logger.info(f"inference throughput prefill: {round(measurements['inference_throughput_prefill'], 5)} 1/s")
+    logger.info(
+        f"inference throughput prefill | seq_len={num_input_tokens}: {round(measurements['inference_throughput_prefill']*num_input_tokens, 5)} tok/sec"
+    )
+    logger.info(f"inference throughput decode: {round(measurements['inference_throughput_decode'], 5)} 1/s")
+    logger.info(
+        f"end-to-end throughput | seq_len={num_input_tokens}: {round((batch_size*(kv_cache_len)/(measurements['preprocessing'] + measurements['loading_weights_model'] + measurements['initializing_KV_cache'] + measurements['compile_total'] + measurements['inference_total']))/num_users, 5)} tok/sec/user"
    )
 
     return generated_text, measurements
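The new measurements derive compile time by subtraction: the first run pays one-time kernel compilation, while the second run (with the persistent kernel cache enabled) measures steady-state inference. A worked sketch of the arithmetic with made-up timings (all numbers hypothetical):

# Hypothetical timings in seconds, mirroring the variables above.
time_prefill_compile = 95.0    # first prefill pass, includes compilation
time_prefill_inference = 5.0   # second prefill pass, steady state
time_decode_compile = 60.0     # first decode pass, includes compilation
time_decode_inference = 20.0   # second decode pass, steady state
num_users = 32
batch_size = 32

compile_prefill = time_prefill_compile - time_prefill_inference    # 90.0 s
compile_decode = time_decode_compile - time_decode_inference       # 40.0 s
compile_total = compile_prefill + compile_decode                   # 130.0 s
inference_throughput_prefill = num_users / time_prefill_inference  # 6.4 1/s
inference_throughput_decode = batch_size / time_decode_inference   # 1.6 1/s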

7 changes: 0 additions & 7 deletions models/demos/falcon7b/tests/test_perf_falcon.py
@@ -334,13 +334,6 @@ def run_test_FalconCausalLM_end_to_end(
     logger.info(f"falcon {comment} inference time: {second_iter_time}")
     logger.info(f"falcon {comment} compile time: {compile_time}")
 
-    if does_pass:
-        logger.info("Falcon CausalLM Passed!")
-    else:
-        logger.warning("Falcon CausalLM Failed!")
-        # TODO: Fix PCC for decode and uncomment this
-        # assert does_pass, f"PCC value is lower than {pcc}"
-
 
 @pytest.mark.models_performance_bare_metal
 @pytest.mark.parametrize(
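For context on the removed block: the commented-out assertion gated the test on PCC (Pearson correlation coefficient) between TT and reference outputs. A minimal sketch of such a check, assuming two torch tensors (illustrative only, not necessarily matching the repo's own comparison helper):

import torch

def pcc(a, b):
    # Pearson correlation between the flattened tensors.
    a = a.flatten().float()
    b = b.flatten().float()
    a = a - a.mean()
    b = b - b.mean()
    return float(torch.dot(a, b) / (a.norm() * b.norm() + 1e-12))

# Example gate, mirroring the commented-out assert:
# assert pcc(tt_output, reference_output) >= 0.99, "PCC below threshold"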
