add eval for attention sink
Pull Request resolved: #7070

This PR adds a function for evaluating the model's perplexity when AttentionSink is enabled.

The implementation is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py, the script the AttentionSink paper uses for the same long-context perplexity evaluation.
ghstack-source-id: 256083503
@exported-using-ghexport

Differential Revision: [D66474732](https://our.internmc.facebook.com/intern/diff/D66474732/)
helunwencser committed Dec 2, 2024
1 parent 70af372 commit e66fa80
Showing 3 changed files with 77 additions and 2 deletions.
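
Before the diff, a minimal sketch of how the new evaluation path might be driven from Python. Only use_attention_sink, attention_sink_eval_tokens, max_seq_length, tokenizer_path, generate_full_logits, and eval_llama_with_attention_sink come from the code below; the module path, the concrete values, the model name, and the "sink_size,window_size,<third field>" string format are assumptions for illustration, not part of this commit.

# Hypothetical driver for the new eval path; values and the parameter-string
# format are assumptions, not part of this commit.
from executorch.examples.models.llama.eval_llama_lib import (
    build_args_parser,
    eval_llama_with_attention_sink,
)

args = build_args_parser().parse_args([])  # assumes the parser has no required flags
args.use_attention_sink = "4,2044,1024"    # assumed "sink_size,window_size,<third field>"
args.max_seq_length = 2048                 # must equal sink_size + window_size
args.attention_sink_eval_tokens = 65536    # number of wikitext-2 tokens to score
args.generate_full_logits = True           # evaluation requires full logits
args.tokenizer_path = "tokenizer.model"    # assumed tokenizer location
ppl = eval_llama_with_attention_sink("llama2", args)  # prints and returns perplexity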
2 changes: 2 additions & 0 deletions examples/models/llama/TARGETS
@@ -150,6 +150,8 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "fbsource//third-party/pypi/tqdm:tqdm",
+        "fbsource//third-party/pypi/datasets:datasets",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
         "fbsource//third-party/pypi/tiktoken:tiktoken",
         ":export_library",
11 changes: 9 additions & 2 deletions examples/models/llama/eval_llama.py
@@ -10,7 +10,11 @@
 
 import torch
 
-from .eval_llama_lib import build_args_parser, eval_llama
+from .eval_llama_lib import (
+    build_args_parser,
+    eval_llama,
+    eval_llama_with_attention_sink,
+)
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -24,7 +28,10 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
-    eval_llama(modelname, args)  # pyre-ignore
+    if args.use_attention_sink:
+        eval_llama_with_attention_sink(modelname, args)  # pyre-ignore
+    else:
+        eval_llama(modelname, args)  # pyre-ignore
 
 
 if __name__ == "__main__":
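For reference, a standalone sketch of how the comma-separated --use_attention_sink value is taken apart by the new eval function in eval_llama_lib.py below. The diff only uses the first two fields; the meaning of the third is not stated in this commit, so it is passed through untouched, and the example value "4,2044,1024" is made up for illustration.

# Sketch mirroring the split/assert logic in eval_llama_with_attention_sink.
def parse_attention_sink_params(value: str, max_seq_length: int) -> tuple[int, int, int]:
    parts = value.split(",")
    assert len(parts) == 3, "expected 'sink_size,window_size,<third field>'"
    sink_size, window_size, third = (int(p) for p in parts)
    # The eval path requires the cache to cover exactly sink + window tokens.
    assert max_seq_length == sink_size + window_size
    return sink_size, window_size, third

print(parse_attention_sink_params("4,2044,1024", max_seq_length=2048))  # (4, 2044, 1024)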
66 changes: 66 additions & 0 deletions examples/models/llama/eval_llama_lib.py
@@ -10,6 +10,8 @@
 from typing import Optional, Union
 
 import torch
+
+from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
@@ -21,6 +23,8 @@
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
+from torch.nn import CrossEntropyLoss
+from tqdm import tqdm
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
@@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
+    # Set of parameters specific to AttentionSink.
+    parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
+
     return parser
 
 
@@ -309,3 +316,62 @@ def eval_llama(
 
     for task, res in eval_results["results"].items():
         print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(
+    model_name: str, args: argparse.ArgumentParser
+):
+    """
+    Evaluate the model's perplexity when AttentionSink is enabled.
+    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
+    """
+    assert args.use_attention_sink is not None  # pyre-ignore [16]
+    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
+    attention_sink_params = args.use_attention_sink.split(",")  # pyre-ignore [16]
+    assert len(attention_sink_params) == 3
+    sink_size = int(attention_sink_params[0])
+    window_size = int(attention_sink_params[1])
+
+    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    model = manager.model.eval().to(device=device)
+    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+
+    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+
+    nlls = []
+    loss_fn = CrossEntropyLoss(reduction="none")
+    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
+    input_pos = 0
+    while input_pos < args.attention_sink_eval_tokens:
+        for text in eval_data["text"]:  # pyre-ignore [16]
+            tokens = tokenizer.encode(text, bos=False, eos=False)
+            if len(tokens) <= 0:
+                continue
+            with torch.no_grad():
+                num_tokens = min(
+                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
+                )
+                logits = model(
+                    torch.tensor(
+                        [tokens[:num_tokens]], dtype=torch.int64, device=device
+                    ),
+                    torch.tensor([input_pos], dtype=torch.int64, device=device),
+                ).squeeze(dim=0)
+                neg_log_likelihood = loss_fn(
+                    logits,
+                    torch.tensor(
+                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
+                    ).view(-1),
+                )
+            nlls.append(neg_log_likelihood)
+            input_pos += num_tokens
+            progress_bar.update(num_tokens)
+            if input_pos >= args.attention_sink_eval_tokens:
+                break
+    ppl = torch.exp(torch.cat(nlls).mean())
+    print(f"Perplexity: {ppl.item()}")
+    return ppl.item()
+
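The number printed and returned above is the usual long-context perplexity, as in the streaming-llm script referenced in the docstring: the exponential of the mean per-token negative log-likelihood. A self-contained sketch of just that reduction, with made-up shapes:

import torch
from torch.nn import CrossEntropyLoss

# Toy illustration of the final reduction in eval_llama_with_attention_sink:
# per-chunk NLL vectors are collected, concatenated, and averaged per token.
loss_fn = CrossEntropyLoss(reduction="none")
logits = torch.randn(5, 32000)           # 5 positions, vocab of 32000 (made-up shapes)
targets = torch.randint(0, 32000, (5,))  # next-token ids for those positions
nlls = [loss_fn(logits, targets)]        # one (num_tokens,) tensor per evaluated chunk
ppl = torch.exp(torch.cat(nlls).mean())  # perplexity = exp(mean per-token NLL)
print(f"Perplexity: {ppl.item()}")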