Remove export test and add as tool
archana-ramalingam committed Oct 25, 2024
1 parent 6190176 commit 9fe2c40
Showing 6 changed files with 440 additions and 177 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/ci_eval.yaml
@@ -59,8 +59,6 @@ jobs:
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
- - name: Export mlir and vmfb
-   run: pytest -v -s sharktank/tests/evaluate/export_artifacts_test.py --bs 4
- name: Run perplexity test with vmfb
run: pytest -n 4 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --iree-device='hip://7' --longrun

72 changes: 61 additions & 11 deletions sharktank/sharktank/evaluate/perplexity_vmfb.py
@@ -32,6 +32,7 @@
from sharktank.utils.vmfb_runner import *
from sharktank.utils.load_llm import *
from sharktank.utils.create_cache import *
from sharktank.utils.export_artifacts import *

log_levels = {
"info": logging.INFO,
@@ -58,14 +59,24 @@ class Perplexity:
"""

def __init__(
- self, torch_device, iree_device, kv_cache_type, tensor_parallelism_size
self,
torch_device,
iree_device,
iree_hip_target,
iree_hal_target_backends,
kv_cache_type,
tensor_parallelism_size,
attention_kernel,
):
self.torch_device = torch_device
self.iree_device = iree_device
self.iree_hip_target = iree_hip_target
self.iree_hal_target_backends = iree_hal_target_backends
self.kv_cache_type = kv_cache_type
self.activation_dtype = torch.float32
self.attention_dtype = torch.float32
self.tensor_parallelism_size = tensor_parallelism_size
self.attention_kernel = attention_kernel

def timeit(func):
def wrapper(*args, **kwargs):
@@ -102,6 +113,19 @@ def print_token_comparison(self, i):
logger.debug(f"{expected_token}")
logger.debug(f"{expected_token_id}")

@timeit
def compile_model(self, weight_path_str):
export_artifacts = ExportArtifacts(
irpa_path=weight_path_str,
batch_size=self.bs,
iree_hip_target=self.iree_hip_target,
iree_hal_target_backends=self.iree_hal_target_backends,
attention_kernel=self.attention_kernel,
tensor_parallelism_size=self.tensor_parallelism_size,
)
vmfb_path = export_artifacts.get_artifacts()
return vmfb_path

@timeit
def load_model(self, weight_path, tokenizer, vmfb_path, weight_path_str):

@@ -130,6 +154,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path, weight_path_str):

self.generator = TorchGenerator(model, tokenizer)

self.weight_path_str = weight_path_str
self.runner = vmfbRunner(
device=self.iree_device,
vmfb_path=vmfb_path,
@@ -151,10 +176,12 @@ def get_prompts(self):
s.replace("\n", "").rstrip()
for s in test_prompts
if s != "" and len(s.split()) >= 20 and s.count("=") < 2
- ]
][0:4]

logger.info(f" num_test_prompts: {len(test_prompts)}")

self.bs = len(test_prompts)

return test_prompts

def prefill_vmfb(self, token_batch, i):
@@ -253,8 +280,6 @@ def get_logits(self):
(self.token_ids != 0).int().detach().clone().to(self.torch_device)
)

- self.bs = len(self.test_prompts)

is_first_token = True
start = 0
for i in tqdm(
@@ -313,6 +338,7 @@ def compute_perplexity(self):
def get_perplexity(self, test_prompts):

self.test_prompts = test_prompts

self.get_logits()

self.out_logits = self.out_logits[..., :-1, :].contiguous()
@@ -331,25 +357,32 @@ def get_perplexity(self, test_prompts):


def run_perplexity(
- vmfb_path,
weight_path,
weight_path_str,
tokenizer,
torch_device,
iree_device,
iree_hip_target,
iree_hal_target_backends,
kv_cache_type,
tensor_parallelism_size,
attention_kernel,
):
perplexity = Perplexity(
torch_device=torch_device,
iree_device=iree_device,
iree_hip_target=iree_hip_target,
iree_hal_target_backends=iree_hal_target_backends,
kv_cache_type=kv_cache_type,
tensor_parallelism_size=tensor_parallelism_size,
attention_kernel=attention_kernel,
)

- perplexity.load_model(weight_path, tokenizer, vmfb_path, weight_path_str)
test_prompts = perplexity.get_prompts()
- ppl = perplexity.get_perplexity(test_prompts=test_prompts)

vmfb_path = perplexity.compile_model(weight_path_str)
perplexity.load_model(weight_path, tokenizer, vmfb_path, weight_path_str)
ppl = perplexity.get_perplexity(test_prompts)

return ppl

@@ -359,7 +392,24 @@ def main(argv):
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument("--torch-device", help="Torch device (or default)")
parser.add_argument("--iree-device", help="List an IREE device, eg: 'hip://0'")
parser.add_argument("--vmfb-path", help="Path to vmfb file")
parser.add_argument(
"--iree-hip-target",
action="store",
default="gfx942",
help="Specify the iree-hip target version (e.g., gfx942)",
)
parser.add_argument(
"--iree-hal-target-backends",
action="store",
default="rocm",
help="Specify the iree-hal target backends (e.g., rocm)",
)
parser.add_argument(
"--attention-kernel",
type=str,
default="decomposed",
choices=["decomposed", "torch_sdpa"],
)
parser.add_argument(
"--tensor-parallelism-size",
type=int,
@@ -376,19 +426,19 @@ def main(argv):
kv_cache_type = args.kv_cache_type
weight_path = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)

- vmfb_path = args.vmfb_path
weight_path_str = str(args.irpa_file)

ppl = run_perplexity(
- vmfb_path=vmfb_path,
weight_path=weight_path,
weight_path_str=weight_path_str,
tokenizer=tokenizer,
torch_device=torch_device,
iree_device=iree_device,
iree_hip_target=args.iree_hip_target,
iree_hal_target_backends=args.iree_hal_target_backends,
kv_cache_type=kv_cache_type,
tensor_parallelism_size=args.tensor_parallelism_size,
attention_kernel=args.attention_kernel,
)

logger.info(f"\n{json.dumps(ppl, indent=2)}")
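Note: with export and compile now folded into the script, a typical invocation looks like the sketch below, assuming the script is run as a module from the sharktank package root. The --tokenizer-config-json flag is assumed to come from the shared cli helpers; the paths are placeholders.

python3 -m sharktank.evaluate.perplexity_vmfb \
  --irpa-file=/path/to/model.irpa \
  --tokenizer-config-json=/path/to/tokenizer_config.json \
  --iree-device='hip://0' \
  --iree-hip-target=gfx942 \
  --iree-hal-target-backends=rocm \
  --attention-kernel=decomposed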
135 changes: 135 additions & 0 deletions sharktank/sharktank/utils/export_artifacts.py
@@ -0,0 +1,135 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
from pathlib import Path
import subprocess
import logging

import iree.compiler as ireec

logger = logging.getLogger("eval")

logger.setLevel(logging.INFO)

logger.root.handlers[0].setFormatter(
logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s")
)


class ExportArtifacts:
def __init__(
self,
irpa_path: str,
batch_size: int,
iree_hip_target: str,
iree_hal_target_backends: str,
attention_kernel: str,
tensor_parallelism_size: int,
):
self.sharktank_dir = str(
Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent
)
self.irpa_path = irpa_path
self.batch_size = batch_size
self.iree_hip_target = iree_hip_target
self.iree_hal_target_backends = iree_hal_target_backends
self.attention_kernel = attention_kernel
self.tensor_parallelism_size = tensor_parallelism_size

def export_to_mlir(
self,
mlir_path: str,
json_path: str,
):
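# Build the CLI invocation of sharktank.examples.export_paged_llm_v1, which writes the MLIR and config JSON.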
export_args = [
"python3",
"-m",
"sharktank.examples.export_paged_llm_v1",
"--irpa-file",
str(self.irpa_path),
"--output-mlir",
mlir_path,
"--output-config",
json_path,
"--bs",
str(self.batch_size),
]
if self.attention_kernel == "decomposed":
export_args.append("--attention-kernel")
export_args.append(self.attention_kernel)
elif self.attention_kernel == "torch_sdpa":
raise NotImplementedError("attention_kernel torch_sdpa not implemented yet")
if self.tensor_parallelism_size:
export_args.append("--tensor-parallelism-size")
export_args.append(str(self.tensor_parallelism_size))

cmd = subprocess.list2cmdline(export_args)

logger.info(
f"export_args: {export_args}\n self.sharktank_dir: {self.sharktank_dir}"
)

cwd = self.sharktank_dir + "/sharktank"

logger.info(f"Exporting mlir:\n" f"cd {cwd} && {cmd}")
proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
return_code = proc.returncode
if return_code != 0:
logger.error("Error exporting mlir: ", return_code)

def compile_to_vmfb(
self,
mlir_path,
vmfb_path,
):
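# --iree-hip-target selects the GPU architecture to compile for (e.g. gfx942).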
compile_flags = ["--iree-hip-target=" + self.iree_hip_target]

ireec.compile_file(
input_file=mlir_path,
target_backends=[self.iree_hal_target_backends],
extra_args=compile_flags,
output_file=vmfb_path,
)

def create_file(self, suffix, prefix):
file_path = Path(prefix).with_suffix(suffix)
# Create (or truncate) the file and close the handle right away.
file_path.open("w").close()
return file_path

def get_artifacts(self):

self.dir_path = self.sharktank_dir + "/" + "tmp_perplexity_ci_artifacts/"
temp_dir = Path(self.dir_path)
temp_dir.mkdir(parents=True, exist_ok=True)

model_name = (
str(self.irpa_path).split("/")[-1].split(".")[0]
+ "_"
+ self.attention_kernel
)
mlir_path = str(
self.create_file(suffix=".mlir", prefix=self.dir_path + model_name)
)
json_path = str(
self.create_file(suffix=".json", prefix=self.dir_path + model_name)
)
vmfb_path = str(
self.create_file(suffix=".vmfb", prefix=self.dir_path + model_name)
)

if self.attention_kernel == "decomposed":
self.export_to_mlir(
mlir_path=mlir_path,
json_path=json_path,
)

self.compile_to_vmfb(
mlir_path=mlir_path,
vmfb_path=vmfb_path,
)

return vmfb_path
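Note: a minimal usage sketch of the new export tool follows; the irpa path is a placeholder and the other values mirror the defaults above.

from sharktank.utils.export_artifacts import ExportArtifacts

artifacts = ExportArtifacts(
    irpa_path="/path/to/model.irpa",  # placeholder
    batch_size=4,
    iree_hip_target="gfx942",
    iree_hal_target_backends="rocm",
    attention_kernel="decomposed",
    tensor_parallelism_size=1,
)
# Exports MLIR + config JSON, compiles it, and returns the vmfb path.
vmfb_path = artifacts.get_artifacts()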
