From 861ead52a26528974878041617b0e09ee58528fe Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Tue, 22 Oct 2024 13:49:32 -0700 Subject: [PATCH 1/3] [sharktank] Evaluation - Update timeout for Perplexity CI test (#305) Update default CI timeout from 6 to 10 hrs --- .github/workflows/ci_eval.yaml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index d3681d95a..9181d5b72 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -16,6 +16,7 @@ concurrency: jobs: test_perplexity: + timeout-minutes: 600 name: "Evaluation Tests - perplexity" strategy: matrix: @@ -59,9 +60,3 @@ jobs: - name: Run perplexity test run: pytest sharktank/tests/evaluate/perplexity_test.py --longrun - - - name: Update Perplexity baseline numbers - uses: actions/upload-artifact@v4 - with: - name: current_perplexity_scores_json - path: ${{ env.SHARK_PLATFORM_REPO_ROOT }}/sharktank/sharktank/evaluate/ From 16be3651032b12602a7a22d714351bb610f446d8 Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Tue, 22 Oct 2024 15:10:14 -0700 Subject: [PATCH 2/3] Benchmark Llama 3.1 f16 and fp8 with CI (#284) Adds pytests for f16 and fp8 with the CI. Currently Llama 3.1 8B f16 is the only test that fully benchmarks through. Llama 3.1 8B fp8, Llama 3.1 70B f16/fp8, and Llama 3.1 405B f16/fp8 tests are marked as XFAIL for now. --------- Signed-off-by: aviator19941 --- .github/workflows/ci-llama.yaml | 78 ++ sharktank/conftest.py | 14 + .../sharktank/examples/export_paged_llm_v1.py | 15 + .../models/llama/benchmark_amdgpu_tests.py | 919 ++++++++++++++++++ 4 files changed, 1026 insertions(+) create mode 100644 .github/workflows/ci-llama.yaml create mode 100644 sharktank/tests/models/llama/benchmark_amdgpu_tests.py diff --git a/.github/workflows/ci-llama.yaml b/.github/workflows/ci-llama.yaml new file mode 100644 index 000000000..6dd9f7d68 --- /dev/null +++ b/.github/workflows/ci-llama.yaml @@ -0,0 +1,78 @@ +name: Llama Benchmarking Tests + +on: + workflow_dispatch: + schedule: + # Weekdays at 9:00 AM UTC = 2:00 AM PST. + - cron: "0 9 * * 1-5" + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test_llama: + name: "Llama Benchmarking Tests" + strategy: + matrix: + version: [3.11] + fail-fast: false + runs-on: llama-mi300 + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + VENV_DIR: ${{ github.workspace }}/.venv + steps: + - name: Get Current Date + id: date + run: echo "::set-output name=date::$(date +'%Y-%m-%d')" + + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@v3 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@v3 + + - name: Cache Pip Packages + uses: actions/cache@v4 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. 
Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Try with the latest nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. + pip install -f https://iree.dev/pip-release-links.html --upgrade \ + iree-compiler \ + iree-runtime \ + "numpy<2.0" + + - name: Run llama test + run: pytest sharktank/tests/models/llama/benchmark_amdgpu_tests.py -v -s --longrun + + - name: Upload llama executable files + uses: actions/upload-artifact@v4 + with: + name: llama-files + path: ${{ github.workspace }}/${{ steps.date.outputs.date }} diff --git a/sharktank/conftest.py b/sharktank/conftest.py index 040775409..a5583b711 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -128,6 +128,13 @@ def pytest_addoption(parser): help="Llama3.1 8B & 405B model baseline perplexity scores json", ) + parser.addoption( + "--iree-hip-target", + action="store", + default="gfx942", + help="Specify the iree-hip target version (e.g., gfx942)", + ) + def set_fixture_from_cli_option( request: FixtureRequest, @@ -168,6 +175,13 @@ def caching(request: FixtureRequest) -> Optional[bool]: return set_fixture_from_cli_option(request, "caching") +@pytest.fixture(scope="class") +def iree_hip_target_type(request: FixtureRequest) -> Optional[str]: + return set_fixture_from_cli_option( + request, "iree_hip_target", "iree_hip_target_type" + ) + + @pytest.fixture(scope="class") def get_model_path(request: FixtureRequest): model_path = {} diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 484d094e3..f1ffa058d 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -16,6 +16,7 @@ # TODO: Should be using a base class with the protocol supported. from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 +from ..models.llama.sharding import shard_theta from ..models.mixtral.mixtral import * from ..models.grok.grok import * @@ -51,6 +52,18 @@ def main(): help="Enables strictness during export", action="store_true", ) + parser.add_argument( + "--attention-kernel", + type=str, + default="decomposed", + choices=["decomposed", "torch_sdpa"], + ) + parser.add_argument( + "--tensor-parallelism-size", + type=int, + default=1, + help="How many devices are involved for tensor parallel sharding.", + ) args = cli.parse(parser) dataset_type = cli.get_input_data_files(args) @@ -59,6 +72,8 @@ def main(): hp = configs.LlamaHParams.from_gguf_props(dataset.properties) llama_config = LlamaModelConfig(hp) + if args.tensor_parallelism_size > 1: + dataset.root_theta = shard_theta(dataset.root_theta, llama_config) llama_config.use_hf = False llama_config.static_tables = False # Rely on the compiler for hoisting tables. 
llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged" diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_tests.py b/sharktank/tests/models/llama/benchmark_amdgpu_tests.py new file mode 100644 index 000000000..6c359a743 --- /dev/null +++ b/sharktank/tests/models/llama/benchmark_amdgpu_tests.py @@ -0,0 +1,919 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +from datetime import datetime +import os +import sys +import unittest +import pytest +import subprocess +from pathlib import Path +from typing import List + +longrun = pytest.mark.skipif("not config.getoption('longrun')") +is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") + + +class ExportMlirException(Exception): + """SHARK-Platform export MLIR exception that preserves the command line and error output.""" + + def __init__(self, process: subprocess.CompletedProcess, cwd: str): + try: + errs = process.stderr.decode("utf-8") + except: + errs = str(process.stderr) + + super().__init__( + f"Error invoking export_paged_llama_v1.py\n" + f"Error code: {process.returncode}\n" + f"Stderr diagnostics:\n{errs}\n\n" + f"Invoked with:\n" + f" cd {cwd} && {process.args}\n\n" + ) + + +class IreeCompileException(Exception): + """Compiler exception that preserves the command line and error output.""" + + def __init__(self, process: subprocess.CompletedProcess, cwd: str): + try: + errs = process.stderr.decode("utf-8") + except: + errs = str(process.stderr) + + super().__init__( + f"Error invoking iree-compile\n" + f"Error code: {process.returncode}\n" + f"Stderr diagnostics:\n{errs}\n\n" + f"Invoked with:\n" + f" cd {cwd} && {process.args}\n\n" + ) + + +class IreeBenchmarkException(Exception): + """Runtime exception that preserves the command line and error output.""" + + def __init__( + self, process: subprocess.CompletedProcess, cwd: str, compile_cmd: str + ): + # iree-run-module sends output to both stdout and stderr + try: + errs = process.stderr.decode("utf-8") + except: + errs = str(process.stderr) + try: + outs = process.stdout.decode("utf-8") + except: + outs = str(process.stdout) + + super().__init__( + f"Error invoking iree-benchmark-module\n" + f"Error code: {process.returncode}\n" + f"Stderr diagnostics:\n{errs}\n" + f"Stdout diagnostics:\n{outs}\n" + f"Compiled with:\n" + f" cd {cwd} && {compile_cmd}\n\n" + f"Run with:\n" + f" cd {cwd} && {process.args}\n\n" + ) + + +@pytest.mark.usefixtures("iree_hip_target_type") +class BaseBenchmarkTest(unittest.TestCase): + directory_created = False + current_date = datetime.now() + dir_path_suffix = current_date.strftime("%Y-%m-%d") + cur_dir = os.path.dirname(os.path.abspath(__file__)) + models_dir = os.path.dirname(cur_dir) + tests_dir = os.path.dirname(models_dir) + sharktank_dir = os.path.dirname(tests_dir) + repo_root = os.path.dirname(sharktank_dir) + dir_path = Path(repo_root + "/" + dir_path_suffix) + + @classmethod + def setUpClass(cls): + """This method will be run once per class to create the directory.""" + if not cls.directory_created: + if not os.path.exists(cls.dir_path): + os.makedirs(cls.dir_path) + cls.directory_created = True + + def setUp(self): + self.hip_device_id = os.getenv("HIP_DEVICE_ID", default="0") + + def create_file(self, *, suffix, prefix): + file_path = Path(prefix).with_suffix(suffix) + f = open(file_path, "w") + 
return file_path + + def get_export_cmd( + self, + *, + attention_kernel: str, + tensor_parallelism_size: int, + irpa_path: str, + output_mlir_path: str, + output_json_path: str, + ): + export_args = [ + "python3", + "-m", + "sharktank.examples.export_paged_llm_v1", + "--irpa-file", + irpa_path, + "--output-mlir", + output_mlir_path, + "--output-config", + output_json_path, + ] + if attention_kernel == "decomposed": + export_args.append("--attention-kernel") + export_args.append(attention_kernel) + elif attention_kernel == "torch_sdpa": + raise NotImplementedError( + "attention_kernel torch_sdpa not yet plumbed through" + ) + if tensor_parallelism_size: + export_args.append("--tensor-parallelism-size") + export_args.append(str(tensor_parallelism_size)) + + cmd = subprocess.list2cmdline(export_args) + return cmd + + def get_compile_cmd( + self, *, output_mlir_path: str, output_vmfb_path: str, args: [str] + ): + compile_args = ["iree-compile", output_mlir_path] + compile_args += args + compile_args += ["-o", output_vmfb_path] + cmd = subprocess.list2cmdline(compile_args) + return cmd + + def export_mlir( + self, + *, + attention_kernel: str, + tensor_parallelism_size: int, + irpa_path: str, + output_mlir_path: str, + output_json_path: str, + cwd: str | Path, + ): + """Runs export_paged_llm_v1.py and exports an MLIR file. + Args: + irpa_path: Path to the model irpa file. + output_mlir_path: Path to the file to save the exported file. + output_json_path: Path to the file to save the config json file. + """ + cmd = self.get_export_cmd( + attention_kernel=attention_kernel, + tensor_parallelism_size=tensor_parallelism_size, + irpa_path=irpa_path, + output_mlir_path=output_mlir_path, + output_json_path=output_json_path, + ) + logging.getLogger().info(f"Launching export command:\n" f"cd {cwd} && {cmd}") + proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) + return_code = proc.returncode + if return_code != 0: + raise ExportMlirException(proc, cwd) + + def iree_compile( + self, + *, + mlir_path: str, + output_vmfb_path: str, + args: List[str], + cwd: str | Path, + ): + """Compiles an input MLIR file to an output .vmfb file. + This assumes that the `iree-compile` command is available (usually via PATH). + Args: + mlir_path: Path to the input MLIR file. + output_vmfb_path: Path for the output .vmfb file. The directory must already exist. + args: List of arguments to pass to `iree-compile`. + cwd: current working directory + Raises Exception if compilation fails for some reason. + """ + cmd = self.get_compile_cmd( + output_mlir_path=mlir_path, + output_vmfb_path=output_vmfb_path, + args=args, + ) + logging.getLogger().info(f"Launching compile command:\n" f"cd {cwd} && {cmd}") + proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) + return_code = proc.returncode + if return_code != 0: + raise IreeCompileException(proc, cwd) + + def iree_benchmark_module( + self, + *, + hip_device_id: str, + vmfb_name: str, + irpa_path: str, + args: List[str], + cwd: str | Path, + ): + """Runs a compiled program with the given args using `iree-benchmark-module`. + This assumes that the `iree-benchmark-module` command is available (usually via PATH). + Args: + vmfb_name: Name of the .vmfb file (relative to `cwd`). + args: List of arguments to pass to `iree-benchmark-module`. + cwd: Working directory to run the command within. (either string or Path works) + compile_cmd: Command used to compile the program, for inclusion in error messages. 
+ Raises Exception if running fails for some reason. + """ + benchmark_args = [ + f"ROCR_VISIBLE_DEVICES={hip_device_id}", + "iree-benchmark-module", + f"--device=hip://{hip_device_id}", + "--hip_use_streams=true", + "--hip_allow_inline_execution=true", + "--device_allocator=caching", + f"--module={vmfb_name}", + f"--parameters=model={irpa_path}", + ] + benchmark_args += args + cmd = subprocess.list2cmdline(benchmark_args) + logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}") + proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd) + return_code = proc.returncode + if return_code != 0: + raise IreeBenchmarkException(proc, cwd, cmd) + + +class BenchmarkLlama3_1_8B(BaseBenchmarkTest): + def setUp(self): + super().setUp() + # TODO: add numpy files to Azure and download from it + artifacts_dir = Path("/data/extra/models/llama3.1_8B") + self.irpa_path = artifacts_dir / "llama8b_f16.irpa" + self.irpa_path_fp8 = artifacts_dir / "llama8b_fp8.irpa" + self.tensor_parallelism_size = None + self.dir_path_8b = self.dir_path / "llama-8b" + self.temp_dir_8b = Path(self.dir_path_8b) + self.temp_dir_8b.mkdir(parents=True, exist_ok=True) + self.iree_compile_args = [ + "--iree-hal-target-backends=rocm", + f"--iree-hip-target={self.iree_hip_target_type}", + ] + self.prefill_args_f16 = artifacts_dir / "prefill_args" + self.decode_args_f16 = artifacts_dir / "decode_args" + self.prefill_args_fp8 = artifacts_dir / "prefill_args_fp8" + self.decode_args_fp8 = artifacts_dir / "decode_args_fp8" + self.iree_run_prefill_args = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_f16}/tokens.npy", + f"--input=@{self.prefill_args_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_f16}/tokens.npy", + f"--input=@{self.decode_args_f16}/seq_lens.npy", + f"--input=@{self.decode_args_f16}/start_positions.npy", + f"--input=@{self.decode_args_f16}/seq_block_ids.npy", + f"--input=@{self.decode_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_prefill_args_fp8 = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_fp8}/tokens.npy", + f"--input=@{self.prefill_args_fp8}/seq_lens.npy", + f"--input=@{self.prefill_args_fp8}/seq_block_ids.npy", + f"--input=@{self.prefill_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args_fp8 = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_fp8}/tokens.npy", + f"--input=@{self.decode_args_fp8}/seq_lens.npy", + f"--input=@{self.decode_args_fp8}/start_positions.npy", + f"--input=@{self.decode_args_fp8}/seq_block_ids.npy", + f"--input=@{self.decode_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + + @longrun + @is_mi300x + def testBenchmark8B_f16_Decomposed(self): + output_file_name = self.dir_path_8b / "f16_decomposed" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + 
f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="torch_sdpa not yet plumbed through", strict=True) + def testBenchmark8B_f16_Non_Decomposed(self): + output_file_name = self.dir_path_8b / "f16_torch_sdpa" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="torch_sdpa", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="8B fp8 irpa path not stored yet", strict=True) + def testBenchmark8B_fp8_Decomposed(self): + output_file_name = self.dir_path_8b / "fp8_decomposed" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path_fp8, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=self.iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="torch_sdpa not yet plumbed through", strict=True) + def testBenchmark8B_fp8_Non_Decomposed(self): + output_file_name = self.dir_path_8b / "fp8_torch_sdpa" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = 
self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="torch_sdpa", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path_fp8, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=self.iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args_fp8, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args_fp8, + cwd=self.repo_root, + ) + + +class BenchmarkLlama3_1_70B(BaseBenchmarkTest): + def setUp(self): + super().setUp() + # TODO: add numpy files to Azure and download from it + artifacts_dir = Path("/data/extra/models/llama3.1_70B") + self.irpa_path = artifacts_dir / "llama70b_f16.irpa" + self.irpa_path_fp8 = artifacts_dir / "llama70b_fp8.irpa" + self.tensor_parallelism_size = 1 + self.dir_path_70b = self.dir_path / "llama-70b" + self.temp_dir_70b = Path(self.dir_path_70b) + self.temp_dir_70b.mkdir(parents=True, exist_ok=True) + self.iree_compile_args = [ + "--iree-hal-target-backends=rocm", + f"--iree-hip-target={self.iree_hip_target_type}", + ] + self.prefill_args_f16 = artifacts_dir / "prefill_args" + self.decode_args_f16 = artifacts_dir / "decode_args" + self.prefill_args_fp8 = artifacts_dir / "prefill_args_fp8" + self.decode_args_fp8 = artifacts_dir / "decode_args_fp8" + self.iree_run_prefill_args = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_f16}/tokens.npy", + f"--input=@{self.prefill_args_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_f16}/tokens.npy", + f"--input=@{self.decode_args_f16}/seq_lens.npy", + f"--input=@{self.decode_args_f16}/start_positions.npy", + f"--input=@{self.decode_args_f16}/seq_block_ids.npy", + f"--input=@{self.decode_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_prefill_args_fp8 = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_fp8}/tokens.npy", + f"--input=@{self.prefill_args_fp8}/seq_lens.npy", + f"--input=@{self.prefill_args_fp8}/seq_block_ids.npy", + f"--input=@{self.prefill_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args_fp8 = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_fp8}/tokens.npy", + f"--input=@{self.decode_args_fp8}/seq_lens.npy", + f"--input=@{self.decode_args_fp8}/start_positions.npy", + f"--input=@{self.decode_args_fp8}/seq_block_ids.npy", + f"--input=@{self.decode_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="70b f16 irpa path not stored yet", strict=True) + def testBenchmark70B_f16_Decomposed(self): + output_file_name = self.dir_path_70b / "f16_decomposed" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json 
= self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="torch_sdpa not yet plumbed through", strict=True) + def testBenchmark70B_f16_Non_Decomposed(self): + output_file_name = self.dir_path_70b / "f16_torch_sdpa" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="torch_sdpa", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="70B fp8 irpa path not stored yet", strict=True) + def testBenchmark70B_fp8_Decomposed(self): + output_file_name = self.dir_path_70b / "fp8_decomposed" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path_fp8, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=self.iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, 
+ irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="torch_sdpa not yet plumbed through", strict=True) + def testBenchmark70B_fp8_Non_Decomposed(self): + output_file_name = self.dir_path_70b / "fp8_torch_sdpa" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="torch_sdpa", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path_fp8, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=self.iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args_fp8, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args_fp8, + cwd=self.repo_root, + ) + + +class BenchmarkLlama3_1_405B(BaseBenchmarkTest): + def setUp(self): + super().setUp() + # TODO: add numpy files to Azure and download from it + artifacts_dir = Path("/data/extra/models/llama3.1_405B") + self.irpa_path = artifacts_dir / "llama405b_f16.irpa" + self.irpa_path_fp8 = artifacts_dir / "llama405b_fp8.irpa" + self.tensor_parallelism_size = 8 + self.dir_path_405b = self.dir_path / "llama-405b" + self.temp_dir_405b = Path(self.dir_path_405b) + self.temp_dir_405b.mkdir(parents=True, exist_ok=True) + self.iree_compile_args = [ + "--iree-hal-target-backends=rocm", + f"--iree-hip-target={self.iree_hip_target_type}", + ] + self.prefill_args_f16 = artifacts_dir / "prefill_args" + self.decode_args_f16 = artifacts_dir / "decode_args" + self.prefill_args_fp8 = artifacts_dir / "prefill_args_fp8" + self.decode_args_fp8 = artifacts_dir / "decode_args_fp8" + self.iree_run_prefill_args = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_f16}/tokens.npy", + f"--input=@{self.prefill_args_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_f16}/tokens.npy", + f"--input=@{self.decode_args_f16}/seq_lens.npy", + f"--input=@{self.decode_args_f16}/start_positions.npy", + f"--input=@{self.decode_args_f16}/seq_block_ids.npy", + f"--input=@{self.decode_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_prefill_args_fp8 = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_fp8}/tokens.npy", + f"--input=@{self.prefill_args_fp8}/seq_lens.npy", + f"--input=@{self.prefill_args_fp8}/seq_block_ids.npy", + f"--input=@{self.prefill_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args_fp8 = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_fp8}/tokens.npy", + f"--input=@{self.decode_args_fp8}/seq_lens.npy", + f"--input=@{self.decode_args_fp8}/start_positions.npy", + 
f"--input=@{self.decode_args_fp8}/seq_block_ids.npy", + f"--input=@{self.decode_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="405B f16 irpa path not stored yet", strict=True) + def testBenchmark405B_f16_Decomposed(self): + output_file_name = self.dir_path_405b / "f16_decomposed" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="torch_sdpa not yet plumbed through", strict=True) + def testBenchmark405B_f16_Non_Decomposed(self): + output_file_name = self.dir_path_405b / "f16_torch_sdpa" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="torch_sdpa", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="405B fp8 irpa path not stored yet", strict=True) + def testBenchmark405B_fp8_Decomposed(self): + output_file_name = self.dir_path_405b / "fp8_decomposed" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path_fp8, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + 
self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=self.iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @longrun + @is_mi300x + @pytest.mark.xfail(reason="torch_sdpa not yet plumbed through", strict=True) + def testBenchmark405B_fp8_Non_Decomposed(self): + output_file_name = self.dir_path_405b / "fp8_torch_sdpa" + output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) + output_json = self.create_file(suffix=".json", prefix=output_file_name) + output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) + self.export_mlir( + attention_kernel="torch_sdpa", + tensor_parallelism_size=self.tensor_parallelism_size, + irpa_path=self.irpa_path_fp8, + output_mlir_path=output_mlir, + output_json_path=output_json, + cwd=self.repo_root, + ) + iree_compile_args = self.iree_compile_args + [ + f"--iree-hal-dump-executable-files-to={output_file_name}/files" + ] + self.iree_compile( + mlir_path=output_mlir, + output_vmfb_path=output_vmfb, + args=self.iree_compile_args, + cwd=self.repo_root, + ) + # benchmark prefill + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args_fp8, + cwd=self.repo_root, + ) + # benchmark decode + self.iree_benchmark_module( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args_fp8, + cwd=self.repo_root, + ) + + +if __name__ == "__main__": + unittest.main() From 4a49a847c873f4940bb0786283c3ddf25c4c6da3 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 22 Oct 2024 18:59:18 -0700 Subject: [PATCH 3/3] [shortfin] Implements RandomGenerator and fill_randn. (#310) For this, I opted to only support in-place variants since in shortfin, we generally operate directly on arrays vs forcing copies or having a lot of options in constructors for managing the device assignment. Otherwise, the definition of randn matches torch.randn. Fixes #261. --- shortfin/python/array_host_ops.cc | 73 ++++++++++++++++++++++++++++ shortfin/python/shortfin/array.py | 4 ++ shortfin/tests/api/array_ops_test.py | 58 ++++++++++++++++++++++ 3 files changed, 135 insertions(+) diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc index b5b4aeeea..8c4af0070 100644 --- a/shortfin/python/array_host_ops.cc +++ b/shortfin/python/array_host_ops.cc @@ -8,6 +8,7 @@ #include "./utils.h" #include "shortfin/array/api.h" #include "shortfin/support/logging.h" +#include "xtensor/xrandom.hpp" #include "xtensor/xsort.hpp" #include "xtl/xhalf_float.hpp" @@ -37,13 +38,59 @@ Implemented for dtypes: float16, float32. A device_array of dtype=int64, allocated on the host and not visible to the device. )"; +static const char DOCSTRING_FILL_RANDN[] = + R"(Fills an array with numbers sampled from the standard ormal distribution. + +Values are samples with a mean of 0 and standard deviation of 1. + +This operates like torch.randn but only supports in place fills to an existing +array, deriving shape and dtype from the output array. + +Args: + out: Output array to fill. 
+ generator: Uses an explicit generator. If not specified, uses a global + default. +)"; + +static const char DOCSTRING_RANDOM_GENERATOR[] = + R"(Returns an object for generating random numbers. + + Every instance is self contained and does not share state with others. + + Args: + seed: Optional seed for the generator. Not setting a seed will cause an + implementation defined value to be used, which may in fact be a completely + fixed number. + )"; + } // namespace +struct PyRandomGenerator { + public: + using SeedType = xt::random::default_engine_type::result_type; + PyRandomGenerator(std::optional seed) { + if (seed) SetSeed(*seed); + } + + static PyRandomGenerator &get_default() { + static PyRandomGenerator default_generator(std::nullopt); + return default_generator; + } + + void SetSeed(SeedType seed) { engine().seed(seed); } + + xt::random::default_engine_type &engine() { return engine_; } + + private: + xt::random::default_engine_type engine_; +}; + #define SF_UNARY_COMPUTE_CASE(dtype_name, cpp_type) \ case DType::dtype_name(): \ return compute.template operator()() void BindArrayHostOps(py::module_ &m) { + // Simple op definitions. m.def( "argmax", [](device_array &input, int axis, std::optional out, @@ -84,6 +131,32 @@ void BindArrayHostOps(py::module_ &m) { py::arg("input"), py::arg("axis") = -1, py::arg("out") = py::none(), py::kw_only(), py::arg("keepdims") = false, py::arg("device_visible") = false, DOCSTRING_ARGMAX); + + // Random number generation. + py::class_(m, "RandomGenerator") + .def(py::init>(), + py::arg("seed") = py::none(), DOCSTRING_RANDOM_GENERATOR); + m.def( + "fill_randn", + [](device_array out, std::optional gen) { + if (!gen) gen = &PyRandomGenerator::get_default(); + auto compute = [&]() { + auto result = xt::random::randn(out.shape_container(), /*mean=*/0.0, + /*std_dev=*/1.0, (*gen)->engine()); + auto out_t = out.map_xtensor_w(); + *out_t = result; + }; + + switch (out.dtype()) { + SF_UNARY_COMPUTE_CASE(float16, half_float::half); + SF_UNARY_COMPUTE_CASE(float32, float); + default: + throw std::invalid_argument( + fmt::format("Unsupported dtype({}) for operator randn", + out.dtype().name())); + } + }, + py::arg("out"), py::arg("generator") = py::none(), DOCSTRING_FILL_RANDN); } } // namespace shortfin::python diff --git a/shortfin/python/shortfin/array.py b/shortfin/python/shortfin/array.py index afc137b72..097980b64 100644 --- a/shortfin/python/shortfin/array.py +++ b/shortfin/python/shortfin/array.py @@ -42,6 +42,8 @@ # Ops. argmax = _sfl.array.argmax +fill_randn = _sfl.array.fill_randn +RandomGenerator = _sfl.array.RandomGenerator __all__ = [ # DType aliases. @@ -78,4 +80,6 @@ "DType", # Ops. "argmax", + "fill_randn", + "RandomGenerator", ] diff --git a/shortfin/tests/api/array_ops_test.py b/shortfin/tests/api/array_ops_test.py index 42bbeeafe..e6d1af3cd 100644 --- a/shortfin/tests/api/array_ops_test.py +++ b/shortfin/tests/api/array_ops_test.py @@ -106,3 +106,61 @@ def test_argmax_dtypes(device, dtype): # some of these. 
src = sfnp.device_array(device, [4, 16, 128], dtype=dtype) sfnp.argmax(src) + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.float16, + sfnp.float32, + ], +) +def test_fill_randn_default_generator(device, dtype): + out1 = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with out1.map(write=True) as m: + m.fill(bytes(1)) + sfnp.fill_randn(out1) + out2 = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with out2.map(write=True) as m: + m.fill(bytes(1)) + sfnp.fill_randn(out2) + + with out1.map(read=True) as m1, out2.map(read=True) as m2: + # The default generator should populate two different arrays. + contents1 = bytes(m1) + contents2 = bytes(m2) + assert contents1 != contents2 + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.float16, + sfnp.float32, + ], +) +def test_fill_randn_explicit_generator(device, dtype): + gen1 = sfnp.RandomGenerator(42) + gen2 = sfnp.RandomGenerator(42) + out1 = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with out1.map(write=True) as m: + m.fill(bytes(1)) + sfnp.fill_randn(out1, generator=gen1) + out2 = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with out2.map(write=True) as m: + m.fill(bytes(1)) + sfnp.fill_randn(out2, generator=gen2) + zero = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with zero.map(write=True) as m: + m.fill(bytes(1)) + + with out1.map(read=True) as m1, out2.map(read=True) as m2, zero.map( + read=True + ) as mz: + # Using explicit generators with the same seed should produce the + # same distributions. + contents1 = bytes(m1) + contents2 = bytes(m2) + assert contents1 == contents2 + # And not be zero. + assert contents1 != bytes(mz)
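
Usage sketch for the fill_randn / RandomGenerator API added in PATCH 3/3 (a minimal, illustrative example mirroring the tests above; it assumes `import shortfin.array as sfnp` and a `device` handle such as the pytest `device` fixture used in shortfin/tests):

    import shortfin.array as sfnp

    # Two generators seeded identically produce identical sample streams.
    gen1 = sfnp.RandomGenerator(42)
    gen2 = sfnp.RandomGenerator(42)

    out1 = sfnp.device_array(device, [4, 16, 128], dtype=sfnp.float32)
    out2 = sfnp.device_array(device, [4, 16, 128], dtype=sfnp.float32)

    # In-place fill with samples from the standard normal distribution
    # (mean 0, standard deviation 1); shape and dtype come from the output array.
    sfnp.fill_randn(out1, generator=gen1)
    sfnp.fill_randn(out2, generator=gen2)

    with out1.map(read=True) as m1, out2.map(read=True) as m2:
        assert bytes(m1) == bytes(m2)

    # Omitting `generator` falls back to the shared default generator.
    sfnp.fill_randn(out1)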