From 2471a3cbf613ab34e6706ccddb0a8a11670ed9b4 Mon Sep 17 00:00:00 2001
From: Salar Hosseini
Date: Thu, 6 Jun 2024 20:42:28 +0000
Subject: [PATCH] #5383: [Falcon7b] Add option to run huggingface model in
 perplexity test, and add perplexity test to demo ci

Signed-off-by: Salar Hosseini
---
 .../falcon7b/tests/test_perplexity_falcon.py | 202 +++++++++++-------
 tests/scripts/t3000/run_t3000_demo_tests.sh  |   4 +
 2 files changed, 132 insertions(+), 74 deletions(-)

diff --git a/models/demos/falcon7b/tests/test_perplexity_falcon.py b/models/demos/falcon7b/tests/test_perplexity_falcon.py
index b777078d333..1851ac5431a 100644
--- a/models/demos/falcon7b/tests/test_perplexity_falcon.py
+++ b/models/demos/falcon7b/tests/test_perplexity_falcon.py
@@ -8,6 +8,7 @@
 from transformers import AutoTokenizer
 from tqdm import tqdm
 import time
+import ttnn
 from models.demos.falcon7b.tt.falcon_causallm import TtFalconCausalLM
 from models.demos.falcon7b.tt.model_config import get_model_config
 from models.demos.falcon7b.tests.test_utils import initialize_kv_cache, load_hf_model
@@ -15,8 +16,8 @@
 from models.utility_functions import is_wormhole_b0, get_devices_for_t3000, tt_tensors_to_torch_tensors
 
 
-def calculate_perplexity(tt_FalconCausalLM, dataloader, llm_mode, batch_size, seq_len, kv_cache, configuration):
-    if llm_mode == "prefill":
+def calculate_perplexity(model, dataloader, llm_mode, batch_size, seq_len, kv_cache, configuration, use_hf_model=False):
+    if llm_mode == "prefill" and not use_hf_model:
         assert batch_size == 1
     use_cache = True
     loss_func = torch.nn.CrossEntropyLoss()
@@ -24,64 +25,79 @@ def calculate_perplexity(tt_FalconCausalLM, dataloader, llm_mode, batch_size, se
     with torch.no_grad():
         for input_ids, labels in tqdm(dataloader, desc="Evaluating batches"):
             if llm_mode == "prefill":
-                user_id = 0
-                (
-                    tt_prefill_input_ids,
-                    tt_prefill_attention_mask,
-                ) = tt_FalconCausalLM.model_preprocessing(
-                    "prefill", input_ids[user_id::batch_size], 0, num_input_tokens=seq_len
-                )
-                tt_logits, kv_cache = tt_FalconCausalLM(
-                    input_ids=tt_prefill_input_ids,
-                    llm_mode="prefill",
-                    attention_mask=tt_prefill_attention_mask,
-                    user_id=user_id,
-                    layer_past=kv_cache,
-                    layer_past_len=0,
-                    use_cache=use_cache,
-                )
-                # Get outputs from all devices
-                tt_logits = torch.concat(
-                    [tt_out_torch.squeeze(1) for tt_out_torch in tt_tensors_to_torch_tensors(tt_logits)]
-                )
-                loss = loss_func(tt_logits.view(batch_size * seq_len, configuration.vocab_size), labels.view(-1))
-                nlls.append(loss.float())
-                # Deallocate tt tensors
-                for i in range(len(tt_logits)):
-                    tt_prefill_input_ids[i].deallocate()
-                    tt_prefill_attention_mask[i].deallocate()
-                    tt_logits[i].deallocate()
-            elif llm_mode == "decode":
-                output_logits = []
-                for kv_cache_len in tqdm(range(seq_len), desc="Decoding tokens for current batch"):
-                    decode_ids = input_ids[:, kv_cache_len].view(batch_size, 1)
-                    (
-                        tt_decode_input_ids,
-                        tt_decode_attention_mask,
-                    ) = tt_FalconCausalLM.model_preprocessing(
-                        "decode", decode_ids, kv_cache_len, num_input_tokens=kv_cache_len + 1
-                    )
-                    tt_logits, kv_cache = tt_FalconCausalLM(
-                        input_ids=tt_decode_input_ids,
-                        llm_mode="decode",
-                        attention_mask=tt_decode_attention_mask,
-                        layer_past=kv_cache,
-                        layer_past_len=kv_cache_len,
-                        use_cache=use_cache,
-                    )
-                    # Get outputs from all devices
-                    logits = torch.concat(
-                        [torch_logit.squeeze(1) for torch_logit in tt_tensors_to_torch_tensors(tt_logits)], dim=-2
-                    )
-                    output_logits.append(logits.view(-1, 1, configuration.vocab_size))
-                    # Deallocate tt tensors
-                    for i in range(len(tt_logits)):
-                        tt_decode_input_ids[i].deallocate()
-                        tt_decode_attention_mask[i].deallocate()
-                        tt_logits[i].deallocate()
-                output_logits = torch.cat(output_logits, dim=1)
-                loss = loss_func(output_logits.view(batch_size * seq_len, configuration.vocab_size), labels.view(-1))
-                nlls.append(loss.float())
+                if not use_hf_model:
+                    user_id = 0
+                    (
+                        tt_prefill_input_ids,
+                        tt_prefill_attention_mask,
+                    ) = model.model_preprocessing(
+                        "prefill", input_ids[user_id::batch_size], 0, num_input_tokens=seq_len
+                    )
+                    tt_logits, kv_cache = model(
+                        input_ids=tt_prefill_input_ids,
+                        llm_mode="prefill",
+                        attention_mask=tt_prefill_attention_mask,
+                        user_id=user_id,
+                        layer_past=kv_cache,
+                        layer_past_len=0,
+                        use_cache=use_cache,
+                    )
+                    # Get outputs from all devices
+                    logits = torch.concat(
+                        [tt_out_torch.squeeze(1) for tt_out_torch in tt_tensors_to_torch_tensors(tt_logits)]
+                    )
+                    # Deallocate tt tensors
+                    for i in range(len(tt_logits)):
+                        tt_prefill_input_ids[i].deallocate()
+                        if isinstance(tt_prefill_attention_mask[i], ttnn.experimental.tensor.Tensor):
+                            tt_prefill_attention_mask[i].deallocate()
+                        elif isinstance(tt_prefill_attention_mask[i], list):
+                            for tt_attention_mask_element in tt_prefill_attention_mask[i]:
+                                tt_attention_mask_element.deallocate()
+                        tt_logits[i].deallocate()
+                else:  # huggingface model
+                    logits, _ = model(input_ids=input_ids, use_cache=use_cache, return_dict=False)
+
+            elif llm_mode == "decode":
+                logits = []
+                layer_present = None
+                for kv_cache_len in tqdm(range(seq_len), desc="Decoding tokens for current batch"):
+                    decode_ids = input_ids[:, kv_cache_len].view(batch_size, 1)
+                    if not use_hf_model:
+                        (
+                            tt_decode_input_ids,
+                            tt_decode_attention_mask,
+                        ) = model.model_preprocessing(
+                            "decode", decode_ids, kv_cache_len, num_input_tokens=kv_cache_len + 1
+                        )
+                        tt_logits, kv_cache = model(
+                            input_ids=tt_decode_input_ids,
+                            llm_mode="decode",
+                            attention_mask=tt_decode_attention_mask,
+                            layer_past=kv_cache,
+                            layer_past_len=kv_cache_len,
+                            use_cache=use_cache,
+                        )
+                        # Get outputs from all devices
+                        logits_cur = torch.concat(
+                            [torch_logit.squeeze(1) for torch_logit in tt_tensors_to_torch_tensors(tt_logits)], dim=-2
+                        )
+                        logits.append(logits_cur.view(-1, 1, configuration.vocab_size))
+                        # Deallocate tt tensors
+                        for i in range(len(tt_logits)):
+                            tt_decode_input_ids[i].deallocate()
+                            tt_decode_attention_mask[i].deallocate()
+                            tt_logits[i].deallocate()
+                    else:  # huggingface model
+                        logits_cur, layer_present = model(
+                            input_ids=decode_ids, past_key_values=layer_present, use_cache=use_cache, return_dict=False
+                        )
+                        logits.append(logits_cur)
+
+                logits = torch.cat(logits, dim=1)
+
+            loss = loss_func(logits.view(batch_size * seq_len, configuration.vocab_size), labels.view(-1))
+            nlls.append(loss.float())
 
     nll = torch.stack(nlls).mean()
     ppl = torch.exp(nll)
@@ -104,6 +120,7 @@ def run_test_perplexity(
     dataset_name="wikitext",
     dataset_config="wikitext-2-raw-v1",
     split="test",
+    use_hf_model=False,
 ):
     # Set random reproducible seed
     torch.manual_seed(0)
@@ -113,26 +130,6 @@ def run_test_perplexity(
     hugging_face_reference_model, state_dict = load_hf_model(model_location_generator, model_version)
     configuration = hugging_face_reference_model.config
 
-    # Load tt-metal model config
-    model_config = get_model_config(model_config_str, max_seq_len)
-    tt_cache_path = get_tt_cache_path(
-        model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
-    )
-
-    # Load tt-metal model
-    logger.info("Moving weights (all layers) to device; might take some time...")
-    tt_FalconCausalLM = TtFalconCausalLM(
-        devices,
-        state_dict,
-        "",
-        num_layers,
-        configuration,
-        max_seq_len,
-        model_config,
-        tt_cache_path,
-        max_seq_len,
-    )
-
     # Prepare dataset
     logger.info("Preparing dataset...")
     dataset = prepare_textgen_dataset(dataset_name, dataset_config, split)
@@ -140,15 +137,39 @@ def run_test_perplexity(
     encodings = tokenizer(dataset, return_tensors="pt")["input_ids"].squeeze(0)
     dataloader = prepare_textgen_dataloader(encodings, batch_size, max_seq_len, num_samples, stride)
 
-    # Initialize kvcache
-    logger.info("Initializing kvcache...")
-    kv_cache = initialize_kv_cache(configuration, num_layers, batch_size, max_seq_len, devices)
+    if not use_hf_model:
+        # Load tt-metal model config
+        model_config = get_model_config(model_config_str, max_seq_len)
+        tt_cache_path = get_tt_cache_path(
+            model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
+        )
+
+        # Load tt-metal model
+        logger.info("Moving weights (all layers) to device; might take some time...")
+        model = TtFalconCausalLM(
+            devices,
+            state_dict,
+            "",
+            num_layers,
+            configuration,
+            max_seq_len,
+            model_config,
+            tt_cache_path,
+            max_seq_len,
+        )
+
+        # Initialize kvcache
+        logger.info("Initializing kvcache...")
+        kv_cache = initialize_kv_cache(configuration, num_layers, batch_size, max_seq_len, devices)
+    else:
+        model = hugging_face_reference_model
+        kv_cache = None
 
     # Evaluate perplexity
     logger.info("Evaluating perplexity...")
     start = time.time()
     nll, ppl = calculate_perplexity(
-        tt_FalconCausalLM, dataloader, llm_mode, batch_size, max_seq_len, kv_cache, configuration
+        model, dataloader, llm_mode, batch_size, max_seq_len, kv_cache, configuration, use_hf_model=use_hf_model
    )
     logger.info(f"Perplexity evaluation time: {(time.time() - start):.2f} s")
     logger.info(f"Negative log-likelihood: {nll:.4f}")
@@ -161,10 +182,43 @@ def run_test_perplexity(
     logger.info("Falcon Perplexity Check Passed!")
 
 
+@pytest.mark.parametrize(
+    "llm_mode, batch_size, max_seq_len, num_samples, expected_ppl",
+    (
+        ("prefill", 32, 1024, 64, 11.5),
+        ("decode", 64, 1024, 64, 11.5),
+    ),
+    ids=[
+        "prefill_seq1024",
+        "decode_1024",
+    ],
+)
+def test_perplexity_huggingface(
+    llm_mode,
+    batch_size,
+    max_seq_len,
+    num_samples,  # Total number of prompts to evaluate (all if None)
+    expected_ppl,
+    model_location_generator,
+):
+    run_test_perplexity(
+        llm_mode,
+        batch_size,
+        max_seq_len,
+        None,
+        model_location_generator,
+        None,
+        None,
+        num_samples,
+        expected_ppl,
+        use_hf_model=True,
+    )
+
+
 @pytest.mark.parametrize(
     "llm_mode, batch_size, max_seq_len, model_config_str, num_samples, expected_ppl",
     (
-        ("prefill", 1, 1024, "BFLOAT16-DRAM", 128, 11.5),
+        ("prefill", 1, 1024, "BFLOAT16-DRAM", 64, 12.0),
         ("decode", 32, 1024, "BFLOAT16-L1_SHARDED", 64, 12.5),
     ),
     ids=[
diff --git a/tests/scripts/t3000/run_t3000_demo_tests.sh b/tests/scripts/t3000/run_t3000_demo_tests.sh
index 5dca3e93a87..d7ac413ff9d 100755
--- a/tests/scripts/t3000/run_t3000_demo_tests.sh
+++ b/tests/scripts/t3000/run_t3000_demo_tests.sh
@@ -30,6 +30,10 @@ run_t3000_falcon7b_tests(){
   WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest --disable-warnings -q -s --input-method=json --input-path='models/demos/t3000/falcon7b/input_data_t3000.json' models/demos/t3000/falcon7b/demo_t3000.py::test_demo_multichip[user_input0-8-True-perf_mode_stochastic_verify]
   WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest --disable-warnings -q -s --input-method=json --input-path='models/demos/t3000/falcon7b/input_data_t3000.json' models/demos/t3000/falcon7b/demo_t3000.py::test_demo_multichip[user_input0-8-True-default_mode_greedy_verify]
 
+  # Falcon7B perplexity test (prefill and decode)
+  WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b/tests/test_perplexity_falcon.py::test_perplexity[1-True-prefill_seq1024_dram]
+  # WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b/tests/test_perplexity_falcon.py::test_perplexity[1-True-decode_1024_l1_sharded]  # Disabled due to Issue #9268
+
   # Record the end time
   end_time=$(date +%s)
   duration=$((end_time - start_time))
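
Note: the new HuggingFace-only perplexity cases added above can also be run locally without T3000 hardware; a sketch, assuming pytest's usual parametrize-id composition for test_perplexity_huggingface (run `pytest --collect-only` on the file to confirm the exact node ids):

  pytest models/demos/falcon7b/tests/test_perplexity_falcon.py::test_perplexity_huggingface -k prefill_seq1024
  pytest models/demos/falcon7b/tests/test_perplexity_falcon.py::test_perplexity_huggingface -k decode_1024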