Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add eval for attention sink #7070

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/models/llama/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ runtime.python_library(
"@EXECUTORCH_CLIENTS",
],
deps = [
"fbsource//third-party/pypi/tqdm:tqdm",
"fbsource//third-party/pypi/datasets:datasets",
"fbsource//third-party/pypi/lm-eval:lm-eval",
"fbsource//third-party/pypi/tiktoken:tiktoken",
":export_library",
Expand Down
11 changes: 9 additions & 2 deletions examples/models/llama/eval_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

import torch

from .eval_llama_lib import build_args_parser, eval_llama
from .eval_llama_lib import (
build_args_parser,
eval_llama,
eval_llama_with_attention_sink,
)

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
Expand All @@ -24,7 +28,10 @@ def main() -> None:
args = parser.parse_args()
# Overrides this arg, because evaluation requires full logits.
args.generate_full_logits = True
eval_llama(modelname, args) # pyre-ignore
if args.use_attention_sink:
eval_llama_with_attention_sink(modelname, args) # pyre-ignore
else:
eval_llama(modelname, args) # pyre-ignore


if __name__ == "__main__":
Expand Down
64 changes: 64 additions & 0 deletions examples/models/llama/eval_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import Optional, Union

import torch

from datasets import load_dataset
from executorch.examples.models.llama.export_llama_lib import (
get_quantizer_and_quant_params,
)
Expand All @@ -21,6 +23,8 @@
)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from lm_eval.evaluator import simple_evaluate
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from .evaluate.eager_eval import EagerEvalWrapper

Expand Down Expand Up @@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
)

# Set of parameters secpific to AttentionSink.
parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)

return parser


Expand Down Expand Up @@ -309,3 +316,60 @@ def eval_llama(

for task, res in eval_results["results"].items():
print(f"{task}: {res}")


def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
"""
Evaluate the model's perplexity when AttentionSink is enabled.

This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
"""
assert args.use_attention_sink is not None # pyre-ignore [16]
assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16]
attention_sink_params = args.use_attention_sink.split(",")
assert len(attention_sink_params) == 3
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])

assert args.max_seq_length == sink_size + window_size # pyre-ignore [16]

device = "cuda" if torch.cuda.is_available() else "cpu"
manager: LLMEdgeManager = _prepare_for_llama_export(args)
model = manager.model.eval().to(device=device)
tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16]

eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

nlls = []
loss_fn = CrossEntropyLoss(reduction="none")
progress_bar = tqdm(total=args.attention_sink_eval_tokens)
input_pos = 0
while input_pos < args.attention_sink_eval_tokens:
for text in eval_data["text"]: # pyre-ignore [16]
tokens = tokenizer.encode(text, bos=False, eos=False)
if len(tokens) <= 0:
continue
with torch.no_grad():
num_tokens = min(
len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
)
logits = model(
torch.tensor(
[tokens[:num_tokens]], dtype=torch.int64, device=device
),
torch.tensor([input_pos], dtype=torch.int64, device=device),
).squeeze(dim=0)
neg_log_likelihood = loss_fn(
logits,
torch.tensor(
[tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
).view(-1),
)
nlls.append(neg_log_likelihood)
input_pos += num_tokens
progress_bar.update(num_tokens)
if input_pos >= args.attention_sink_eval_tokens:
break
ppl = torch.exp(torch.cat(nlls).mean())
print(f"Perplexity: {ppl.item()}")
return ppl.item()