diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 284520d4d5..445bcd673b 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -150,6 +150,8 @@ runtime.python_library( "@EXECUTORCH_CLIENTS", ], deps = [ + "fbsource//third-party/pypi/tqdm:tqdm", + "fbsource//third-party/pypi/datasets:datasets", "fbsource//third-party/pypi/lm-eval:lm-eval", "fbsource//third-party/pypi/tiktoken:tiktoken", ":export_library", diff --git a/examples/models/llama/eval_llama.py b/examples/models/llama/eval_llama.py index 09157789bd..7c959d08b9 100644 --- a/examples/models/llama/eval_llama.py +++ b/examples/models/llama/eval_llama.py @@ -10,7 +10,11 @@ import torch -from .eval_llama_lib import build_args_parser, eval_llama +from .eval_llama_lib import ( + build_args_parser, + eval_llama, + eval_llama_with_attention_sink, +) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -24,7 +28,10 @@ def main() -> None: args = parser.parse_args() # Overrides this arg, because evaluation requires full logits. args.generate_full_logits = True - eval_llama(modelname, args) # pyre-ignore + if args.use_attention_sink: + eval_llama_with_attention_sink(modelname, args) # pyre-ignore + else: + eval_llama(modelname, args) # pyre-ignore if __name__ == "__main__": diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index dd01365ba5..a7f0f88cd9 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -10,6 +10,8 @@ from typing import Optional, Union import torch + +from datasets import load_dataset from executorch.examples.models.llama.export_llama_lib import ( get_quantizer_and_quant_params, ) @@ -21,6 +23,8 @@ ) from executorch.extension.llm.tokenizer.utils import get_tokenizer from lm_eval.evaluator import simple_evaluate +from torch.nn import CrossEntropyLoss +from tqdm import tqdm from .evaluate.eager_eval import EagerEvalWrapper @@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser: help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", ) + # Set of parameters secpific to AttentionSink. + parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) + return parser @@ -309,3 +316,60 @@ def eval_llama( for task, res in eval_results["results"].items(): print(f"{task}: {res}") + + +def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): + """ + Evaluate the model's perplexity when AttentionSink is enabled. + + This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py + """ + assert args.use_attention_sink is not None # pyre-ignore [16] + assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16] + attention_sink_params = args.use_attention_sink.split(",") + assert len(attention_sink_params) == 3 + sink_size = int(attention_sink_params[0]) + window_size = int(attention_sink_params[1]) + + assert args.max_seq_length == sink_size + window_size # pyre-ignore [16] + + device = "cuda" if torch.cuda.is_available() else "cpu" + manager: LLMEdgeManager = _prepare_for_llama_export(args) + model = manager.model.eval().to(device=device) + tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16] + + eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + nlls = [] + loss_fn = CrossEntropyLoss(reduction="none") + progress_bar = tqdm(total=args.attention_sink_eval_tokens) + input_pos = 0 + while input_pos < args.attention_sink_eval_tokens: + for text in eval_data["text"]: # pyre-ignore [16] + tokens = tokenizer.encode(text, bos=False, eos=False) + if len(tokens) <= 0: + continue + with torch.no_grad(): + num_tokens = min( + len(tokens) - 1, args.attention_sink_eval_tokens - input_pos + ) + logits = model( + torch.tensor( + [tokens[:num_tokens]], dtype=torch.int64, device=device + ), + torch.tensor([input_pos], dtype=torch.int64, device=device), + ).squeeze(dim=0) + neg_log_likelihood = loss_fn( + logits, + torch.tensor( + [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device + ).view(-1), + ) + nlls.append(neg_log_likelihood) + input_pos += num_tokens + progress_bar.update(num_tokens) + if input_pos >= args.attention_sink_eval_tokens: + break + ppl = torch.exp(torch.cat(nlls).mean()) + print(f"Perplexity: {ppl.item()}") + return ppl.item()