diff --git a/tuner/examples/dispatch/.gitignore b/tuner/examples/dispatch/.gitignore
deleted file mode 100644
index 9fb2fe16a..000000000
--- a/tuner/examples/dispatch/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-# Test files/dirs recommended by README.md.
-dump/
-benchmark.mlir
diff --git a/tuner/examples/dispatch/README.md b/tuner/examples/dispatch/README.md
deleted file mode 100644
index 70c46e08a..000000000
--- a/tuner/examples/dispatch/README.md
+++ /dev/null
@@ -1,35 +0,0 @@
-# Dispatch Tuner
-
-Allows tuning a single dispatch in isolation.
-
-## Environments
-Follow instructions in [`/tuner/README.md`](../README.md)
-
-## Running the Dispatch Tuner
-
-### Generate a benchmark file
-Use the usual `iree-compile` command for your dispatch and add
-`--iree-hal-dump-executable-files-to=dump`. For example:
-```shell
-iree-compile mmt.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=dump -o /dev/null
-```
-
-Next, copy the `*_benchmark.mlir` file to some temporary directory of choice.
-This will be the input to the dispatch tuner.
-
-### Recommended Trial Run
-For an initial trial to test the tuning loop, use:
-```shell
-python -m examples.dispatch benchmark.mlir --num-candidates=20
-```
-
-### Dry Run Test
-To perform a dry run (no GPU required), use:
-```shell
-python -m examples.dispatch benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run
-```
-
-### Basic Usage
-```shell
-python -m examples.dispatch benchmark.mlir
-```
diff --git a/tuner/examples/dispatch/__init__.py b/tuner/examples/dispatch/__init__.py
deleted file mode 100644
index a85ba359d..000000000
--- a/tuner/examples/dispatch/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/tuner/examples/dispatch/__main__.py b/tuner/examples/dispatch/__main__.py
deleted file mode 100644
index 9fb86fd9f..000000000
--- a/tuner/examples/dispatch/__main__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from . import dispatch_tuner
-
-dispatch_tuner.main()
diff --git a/tuner/examples/dispatch/compile_dispatch.sh b/tuner/examples/dispatch/compile_dispatch.sh
deleted file mode 100755
index 0b01ac991..000000000
--- a/tuner/examples/dispatch/compile_dispatch.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/usr/bin/env bash
-
-set -eou pipefail
-
-readonly INPUT="$1"
-readonly DIR="$(dirname "$INPUT")"
-readonly BASENAME="$(basename "$INPUT" .mlir)"
-readonly OUT="${DIR}/compiled/${BASENAME}.vmfb"
-
-iree-compile "$INPUT" -o "$OUT" \
-  --compile-from=executable-sources 2>/dev/null || (mv "$INPUT" "$DIR/failed" && exit 1)
-
-iree-dump-module "$OUT" | grep -q 'rocm-hsaco-fb' || (mv "$INPUT" "$DIR/failed" && rm -f "$OUT" && exit 1)
-if [ -f "${DIR}/${BASENAME}_config.mlir" ]; then
-  cat "${DIR}/../config_prolog.mlir" "${DIR}/${BASENAME}_config.mlir" "${DIR}/../config_epilog.mlir" > "${DIR}/specs/${BASENAME}_spec.mlir"
-fi
-
-echo "Compiling ${INPUT}: success"
diff --git a/tuner/examples/dispatch/config_epilog.mlir b/tuner/examples/dispatch/config_epilog.mlir
deleted file mode 100644
index c15a30502..000000000
--- a/tuner/examples/dispatch/config_epilog.mlir
+++ /dev/null
@@ -1,12 +0,0 @@
-
-//===----------------------------------------------------------------------===//
-// Entry point
-//===----------------------------------------------------------------------===//
-
-  transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.consumed}) {
-    transform.foreach_match in %variant_op
-        , @match_op -> @apply_op_config
-      : (!transform.any_op) -> (!transform.any_op)
-    transform.yield
-  }
-} //// module
diff --git a/tuner/examples/dispatch/config_prolog.mlir b/tuner/examples/dispatch/config_prolog.mlir
deleted file mode 100644
index 377ac3f8f..000000000
--- a/tuner/examples/dispatch/config_prolog.mlir
+++ /dev/null
@@ -1,32 +0,0 @@
-// Transform dialect specification for attention on MI300 with MFMA.
-module attributes { transform.with_named_sequence } {
-//===----------------------------------------------------------------------===//
-// Matmul tuning
-//===----------------------------------------------------------------------===//
-
-  transform.named_sequence @match_mmt_f16_f16_f32(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
-    transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
-    // transform.print %root {name = "Generic"} : !transform.any_op
-    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
-    ^bb0(%lhs: tensor<?x?xf16>, %rhs: tensor<?x?xf16>, %out: tensor<?x?xf32>):
-      %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
-                                            affine_map<(d0, d1, d2) -> (d1, d2)>,
-                                            affine_map<(d0, d1, d2) -> (d0, d1)>],
-                           iterator_types = ["parallel", "parallel", "reduction"]}
-          ins(%lhs, %rhs : tensor<?x?xf16>, tensor<?x?xf16>) outs(%out : tensor<?x?xf32>) {
-        ^bb0(%in: f16, %in_0: f16, %acc: f32):
-          %8 = arith.extf %in : f16 to f32
-          %9 = arith.extf %in_0 : f16 to f32
-          %10 = arith.mulf %8, %9 : f32
-          %11 = arith.addf %acc, %10 : f32
-          linalg.yield %11 : f32
-        } -> tensor<?x?xf32>
-    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
-    transform.yield %root : !transform.any_op
-  }
-
-  transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly}, %config: !transform.any_param {transform.readonly}) {
-    transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
-    // transform.print %op {name = "Applied"} : !transform.any_op
-    transform.yield
-  }
diff --git a/tuner/examples/dispatch/dispatch_tuner.py b/tuner/examples/dispatch/dispatch_tuner.py
deleted file mode 100644
index 0f5b54979..000000000
--- a/tuner/examples/dispatch/dispatch_tuner.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-"""
-Sample Usage:
-
-python -m examples.dispatch benchmark.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=hip://0,hip://1 --num-candidates=64
-
-
-Recommended Trial Run:
-
-python -m examples.dispatch benchmark.mlir --num-candidates=10
-
-
-Dry Run Test (no gpu required):
-
-python -m examples.dispatch benchmark.mlir --num-candidates=64 --dry-run
-
-"""
-
-from tuner import libtuner
-from pathlib import Path, PurePath
-import os
-
-
-class DispatchTuner(libtuner.TuningClient):
-    def get_dispatch_compile_timeout_s(self) -> int:
-        return 10
-
-    def get_dispatch_compile_command(
-        self, candidate_tracker: libtuner.CandidateTracker
-    ) -> list[str]:
-        assert candidate_tracker.dispatch_mlir_path is not None
-        mlir_path: Path = candidate_tracker.dispatch_mlir_path
-        script_dir = Path(__file__).resolve().parent
-        command = [
-            (script_dir / "compile_dispatch.sh").as_posix(),
-            mlir_path.as_posix(),
-        ]
-        return command
-
-    def get_dispatch_benchmark_timeout_s(self) -> int:
-        return 15
-
-    def get_dispatch_benchmark_command(
-        self,
-        candidate_tracker: libtuner.CandidateTracker,
-    ) -> list[str]:
-        compiled_vmfb_path = candidate_tracker.compiled_dispatch_path
-        assert compiled_vmfb_path is not None
-
-        command = [
-            "iree-benchmark-module",
-            f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
-            f"--module={compiled_vmfb_path.resolve()}",
-            "--batch_size=1000",
-            "--benchmark_repetitions=3",
-            "--benchmark_format=json",
-        ]
-
-        return command
-
-    def get_model_compile_timeout_s(self) -> int:
-        return 0
-
-    def get_model_compile_command(
-        self, candidate_tracker: libtuner.CandidateTracker
-    ) -> list[str]:
-        return []
-
-    def get_model_benchmark_timeout_s(self) -> int:
-        return 0
-
-    def get_model_benchmark_command(
-        self, candidate_tracker: libtuner.CandidateTracker
-    ) -> list[str]:
-        return []
-
-    def get_iree_compile_flags(self) -> list[str]:
-        return []
-
-    def get_iree_benchmark_module_flags(self) -> list[str]:
-        return []
-
-    def get_benchmark_timeout_s(self) -> int:
-        return 0
-
-
-def main():
-    args = libtuner.parse_arguments()
-    path_config = libtuner.PathConfig()
-    # These will not be used, so always default to the empty config in the script dir.
-    script_dir = Path(__file__).resolve().parent
-    path_config.global_config_prolog_mlir = (
-        script_dir / path_config.global_config_prolog_mlir
-    )
-    path_config.global_config_epilog_mlir = (
-        script_dir / path_config.global_config_epilog_mlir
-    )
-    path_config.base_dir.mkdir(parents=True, exist_ok=True)
-    path_config.output_unilog.touch()
-    candidate_trackers: list[libtuner.CandidateTracker] = []
-    dispatch_tuner = DispatchTuner()
-    stop_after_phase: str = args.stop_after
-
-    print("Setup logging")
-    libtuner.setup_logging(args, path_config)
-    print(path_config.run_log, end="\n\n")
-
-    if not args.dry_run:
-        print("Validating devices")
-        libtuner.validate_devices(args.devices)
-        print("Validation successful!\n")
-
-    print("Generating candidates...")
-    candidates = libtuner.generate_candidates(args, path_config, candidate_trackers)
-    print(f"Stored candidates in {path_config.candidates_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
-        return
-
-    print("Compiling candidates...")
-    compiled_candidates = libtuner.compile_dispatches(
-        args, path_config, candidates, candidate_trackers, dispatch_tuner
-    )
-    print(f"Compiled files are stored in {path_config.compiled_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
-        return
-
-    print("Benchmarking compiled candidates...")
-    top_candidates = libtuner.benchmark_dispatches(
-        args, path_config, compiled_candidates, candidate_trackers, dispatch_tuner
-    )
-    print(f"\nStored results in {path_config.output_unilog.resolve()}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
-        return
-
-    libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
-    print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")
-
-    print("Check the detailed execution logs in:")
-    print(path_config.run_log.resolve())
-
-    for candidate in candidate_trackers:
-        libtuner.logging.debug(candidate)
diff --git a/tuner/examples/dispatch/mmt.mlir b/tuner/examples/dispatch/mmt.mlir
deleted file mode 100644
index b9d6c5f4c..000000000
--- a/tuner/examples/dispatch/mmt.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-!matA_0 = tensor<2048x1280xf16>
-!matB_0 = tensor<10240x1280xf16>
-!matC_0 = tensor<2048x10240xf32>
-
-func.func @main_0(%arg0: !matA_0, %arg1: !matB_0) -> !matC_0 {
-  %cst = arith.constant 0.000000e+00 : f16
-  %5 = tensor.empty() : !matC_0
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : !matC_0) -> !matC_0
-  %8 = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%6 : !matC_0) -> !matC_0
-  return %8 : !matC_0
-}
diff --git a/tuner/examples/punet/.gitignore b/tuner/examples/punet/.gitignore
deleted file mode 100644
index fae904ffb..000000000
--- a/tuner/examples/punet/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-# Test files/dirs recommended by README.md.
-dump-mmt
-test-benchmark.mlir
diff --git a/tuner/examples/punet/README.md b/tuner/examples/punet/README.md
deleted file mode 100644
index 777d1c194..000000000
--- a/tuner/examples/punet/README.md
+++ /dev/null
@@ -1,46 +0,0 @@
-# Punet Tuner
-
-## Environments
-Follow instructions in [`/tuner/README.md`](../README.md)
-
-## Shell Scripts
-
-The required shell scripts can be downloaded from:
-[sdxl-scripts](https://github.com/nod-ai/sdxl-scripts).
-
-These scripts include:
-1. `compile-punet-base.sh` - Used for compiling model candidates.
-2. `compile_candidate.sh` - Used for compiling dispatch candidates.
-3. `punet.sh` - Invoked by `compile_candidate.sh`.
-
-Add the parent directories of these scripts to your `PATH` environment variable,
-so that they can be picked up by `punet_autotune.py`.
-
-## Running the Tuner
-
-### [Optional] Generate a tunable mlir
-Use
-[`punet.sh`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/punet.sh)
-to compile the sample matmul `mmt.mlir` (can also be found here:
-[`mmt_unet.mlir`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/mmt_unet.mlir)):
-```shell
-punet.sh mmt.mlir -o mmt.vmfb --iree-hal-dump-executable-files-to=dump-mmt
-cp ./dump-mmt/module_main_0_dispatch_0_rocm_hsaco_fb_benchmark.mlir test-benchmark.mlir
-```
-
-### Recommended Trial Run
-For an initial trial to test the tuning loop, use:
-```shell
-python -m examples.punet test-benchmark.mlir --num-candidates=10
-```
-
-### Dry Run Test
-To perform a dry run (no GPU required), use:
-```shell
-python -m examples.punet test-benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run
-```
-
-### Basic Usage
-```shell
-python -m examples.punet test-benchmark.mlir
-```
diff --git a/tuner/examples/punet/__init__.py b/tuner/examples/punet/__init__.py
deleted file mode 100644
index a85ba359d..000000000
--- a/tuner/examples/punet/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/tuner/examples/punet/__main__.py b/tuner/examples/punet/__main__.py
deleted file mode 100644
index ca092d502..000000000
--- a/tuner/examples/punet/__main__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from . import punet_autotune
-
-punet_autotune.main()
diff --git a/tuner/examples/punet/mmt.mlir b/tuner/examples/punet/mmt.mlir
deleted file mode 100644
index b9d6c5f4c..000000000
--- a/tuner/examples/punet/mmt.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-!matA_0 = tensor<2048x1280xf16>
-!matB_0 = tensor<10240x1280xf16>
-!matC_0 = tensor<2048x10240xf32>
-
-func.func @main_0(%arg0: !matA_0, %arg1: !matB_0) -> !matC_0 {
-  %cst = arith.constant 0.000000e+00 : f16
-  %5 = tensor.empty() : !matC_0
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : !matC_0) -> !matC_0
-  %8 = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%6 : !matC_0) -> !matC_0
-  return %8 : !matC_0
-}
diff --git a/tuner/examples/punet/punet_autotune.py b/tuner/examples/punet/punet_autotune.py
deleted file mode 100644
index 2bfdb4d24..000000000
--- a/tuner/examples/punet/punet_autotune.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-"""
-Sample Usage:
-
-python -m examples.punet benchmark.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=hip://0,hip://1 --num-candidates=64
-
-
-Recommended Trial Run:
-
-python -m examples.punet benchmark.mlir --num-candidates=1
-
-
-Dry Run Test (no gpu required):
-
-python -m examples.punet benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run
-
-"""
-
-from tuner import libtuner
-from pathlib import Path
-
-
-class PunetClient(libtuner.TuningClient):
-    def get_dispatch_compile_timeout_s(self) -> int:
-        return 4
-
-    def get_dispatch_compile_command(
-        self, candidate_tracker: libtuner.CandidateTracker
-    ) -> list[str]:
-        mlir_path = candidate_tracker.dispatch_mlir_path
-        assert mlir_path is not None
-        command = [
-            "compile_candidate.sh",
-            mlir_path.as_posix(),
-        ]
-        return command
-
-    def get_dispatch_benchmark_timeout_s(self) -> int:
-        return 15
-
-    def get_dispatch_benchmark_command(
-        self,
-        candidate_tracker: libtuner.CandidateTracker,
-    ) -> list[str]:
-        compiled_vmfb_path = candidate_tracker.compiled_dispatch_path
-        assert compiled_vmfb_path is not None
-
-        command = [
-            "iree-benchmark-module",
-            f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
-            f"--module={compiled_vmfb_path.resolve()}",
-            "--hip_use_streams=true",
-            "--hip_allow_inline_execution=true",
-            "--batch_size=1000",
-            "--benchmark_repetitions=3",
-            "--benchmark_format=json",
-        ]
-
-        return command
-
-    def get_model_compile_timeout_s(self) -> int:
-        return 300
-
-    def get_model_compile_command(
-        self, candidate_tracker: libtuner.CandidateTracker
-    ) -> list[str]:
-        mlir_spec_path = candidate_tracker.spec_path
-        assert mlir_spec_path is not None
-        target_dir = mlir_spec_path.resolve().parent.parent.parent
-        output_name = f"unet_candidate_{candidate_tracker.candidate_id}.vmfb"
-        command = [
-            "compile-punet-base.sh",
-            "iree-compile",
-            "gfx942",
-            f"{mlir_spec_path.resolve()}",
-            "./punet.mlir",
-            "-o",
-            (target_dir / output_name).as_posix(),
-        ]
-        return command
-
-    def get_model_benchmark_timeout_s(self) -> int:
-        return 180
-
-    def get_model_benchmark_command(
-        self, candidate_tracker: libtuner.CandidateTracker
-    ) -> list[str]:
-        unet_candidate_path = candidate_tracker.compiled_model_path
-        assert unet_candidate_path is not None
-
-        command = [
-            "iree-benchmark-module",
-            f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
-            "--hip_use_streams=true",
-            "--hip_allow_inline_execution=true",
-            "--device_allocator=caching",
-            f"--module={unet_candidate_path.resolve()}",
-            "--parameters=model=punet.irpa",
-            "--function=main",
-            "--input=1x4x128x128xf16",
-            "--input=1xsi32",
-            "--input=2x64x2048xf16",
-            "--input=2x1280xf16",
-            "--input=2x6xf16",
-            "--input=1xf16",
-            "--benchmark_repetitions=5",
-            "--benchmark_format=json",
-        ]
-        return command
-
-    def get_iree_compile_flags(self) -> list[str]:
-        return []
-
-    def get_iree_benchmark_module_flags(self) -> list[str]:
-        return []
-
-    def get_benchmark_timeout_s(self) -> int:
-        return 0
-
-
-def main():
-    args = libtuner.parse_arguments()
-    path_config = libtuner.PathConfig()
-    path_config.base_dir.mkdir(parents=True, exist_ok=True)
-    path_config.output_unilog.touch()
-    candidate_trackers: list[libtuner.CandidateTracker] = []
-    punet_client = PunetClient()
-    stop_after_phase: str = args.stop_after
-
-    print("Setup logging")
-    libtuner.setup_logging(args, path_config)
-    print(path_config.run_log, end="\n\n")
-
-    if not args.dry_run:
-        print("Validating devices")
-        libtuner.validate_devices(args.devices)
-        print("Validation successful!\n")
-
-    print("Generating candidates...")
-    candidates = libtuner.generate_candidates(args, path_config, candidate_trackers)
-    print(f"Stored candidates in {path_config.candidates_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
-        return
-
-    print("Compiling candidates...")
-    compiled_candidates = libtuner.compile_dispatches(
-        args, path_config, candidates, candidate_trackers, punet_client
-    )
-    print(f"Compiled files are stored in {path_config.compiled_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
-        return
-
-    print("Benchmarking compiled candidates...")
-    top_candidates = libtuner.benchmark_dispatches(
-        args, path_config, compiled_candidates, candidate_trackers, punet_client
-    )
-    print(f"Stored results in {path_config.output_unilog}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
-        return
-
-    print(f"Compiling top model candidates...")
-    punet_candidates = libtuner.compile_models(
-        args, path_config, top_candidates, candidate_trackers, punet_client
-    )
-    print(f"Model candidates compiled in {path_config.base_dir}\n")
-    if stop_after_phase == libtuner.ExecutionPhases.compile_models:
-        return
-
-    print("Benchmarking model candidates...")
-    libtuner.benchmark_models(
-        args, path_config, punet_candidates, candidate_trackers, punet_client
-    )
-    print(f"Stored results in {path_config.output_unilog}")
-    if stop_after_phase == libtuner.ExecutionPhases.benchmark_models:
-        return
-
-    libtuner.summerize_top_candidates(path_config, candidate_trackers)
-    print(f"Stored top candidates info in {path_config.result_summary_log}\n")
-
-    libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
-    print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")
-
-    print("Check the detailed execution logs in:")
-    print(path_config.run_log)
-
-    for candidate in candidate_trackers:
-        libtuner.logging.debug(candidate)
-        if args.verbose:
-            print(candidate)
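Both deleted example clients implement the same `libtuner.TuningClient` interface and differ only in the commands and timeouts they return. For orientation, a minimal sketch of that contract, using only names that appear in the deleted files above; a client opts out of the model phase by returning empty commands and zero timeouts, as `DispatchTuner` does:

```python
from tuner import libtuner


class MinimalClient(libtuner.TuningClient):
    """Sketch: the smallest client the dispatch phases need."""

    def get_dispatch_compile_timeout_s(self) -> int:
        return 10

    def get_dispatch_compile_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        assert candidate_tracker.dispatch_mlir_path is not None
        # Any executable on PATH works; compile_candidate.sh is what the
        # punet example used.
        return ["compile_candidate.sh", candidate_tracker.dispatch_mlir_path.as_posix()]

    def get_dispatch_benchmark_timeout_s(self) -> int:
        return 15

    def get_dispatch_benchmark_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        assert candidate_tracker.compiled_dispatch_path is not None
        return [
            "iree-benchmark-module",
            f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
            f"--module={candidate_tracker.compiled_dispatch_path.resolve()}",
            "--benchmark_format=json",
        ]

    # Model-phase hooks; return empty values to skip the phase, as
    # DispatchTuner did.
    def get_model_compile_timeout_s(self) -> int:
        return 0

    def get_model_compile_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        return []

    def get_model_benchmark_timeout_s(self) -> int:
        return 0

    def get_model_benchmark_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        return []

    def get_iree_compile_flags(self) -> list[str]:
        return []

    def get_iree_benchmark_module_flags(self) -> list[str]:
        return []

    def get_benchmark_timeout_s(self) -> int:
        return 0
```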
diff --git a/tuner/examples/test/README.md b/tuner/examples/test/README.md
index 5dfba0da3..47ae7a8fe 100644
--- a/tuner/examples/test/README.md
+++ b/tuner/examples/test/README.md
@@ -35,5 +35,6 @@ python -m examples.test double_mmt.mlir mmt_benchmark.mlir \
 python -m examples.test \
     --test_num_dispatch_candidates=<num_dispatch_candidates> \
     --test_num_model_candidates=<num_model_candidates> \
-    --test_hip_target=<hip_target> \ --num-candidates=<num_candidates>
+    --test_hip_target=<hip_target> \
+    --num-candidates=<num_candidates>
 ```
diff --git a/tuner/examples/test/tuner_test.py b/tuner/examples/test/tuner_test.py
index 528f03b80..22a0d2f4d 100644
--- a/tuner/examples/test/tuner_test.py
+++ b/tuner/examples/test/tuner_test.py
@@ -90,7 +90,6 @@ def main():
     path_config = libtuner.PathConfig()
     path_config.base_dir.mkdir(parents=True, exist_ok=True)
-    path_config.output_unilog.touch()
     # TODO(Max191): Make candidate_trackers internal to TuningClient.
     candidate_trackers: list[libtuner.CandidateTracker] = []
     stop_after_phase: str = args.stop_after

diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index 45cb3512a..b6264792e 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -11,21 +11,17 @@ Generate candidates by tweaking op configuration for tuning.

 It can be invoked in two ways:
-    1. From another python script, import and call `tune()`
+    1. From another python script, import and call `generate_configs_and_td_specs()`
     2. Run this script directly from the command
-
-Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk
-
+Usage: python -m tuner.candidate_gen mmt_benchmark.mlir -o spec_dir -l 1024
 """

 import argparse
 import logging
-import pickle
-import re
 from dataclasses import dataclass
-from os import path, makedirs
+from pathlib import Path
+import subprocess
 from typing import Optional
-from textwrap import indent
 from abc import abstractmethod

 from iree.compiler import ir  # type: ignore

@@ -40,61 +36,6 @@
 tune_logger = logging.getLogger("tune")
-
-
-def apply_configuration(
-    template: list[str],
-    compilation_info: iree_codegen.CompilationInfoAttr,
-) -> str:
-    lowering_config = compilation_info.lowering_config
-    intrinsic = lowering_config.mma_kind
-    (
-        subgroup_m_count,
-        subgroup_n_count,
-    ) = lowering_config.subgroup_count_mn
-    workgroup_sizes = lowering_config.workgroup_tile_sizes
-    reduction_sizes = lowering_config.reduction_tile_sizes
-    gpu_pipeline_options = compilation_info.translation_info.configuration[
-        GPU_PIPELINE_OPTIONS_KEY
-    ]
-    waves_per_eu = compilation_info.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][
-        WAVES_PER_EU_KEY
-    ]
-    tune_logger.info(f"Applying: {compilation_info}")
-    expr0 = re.compile(
-        r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
-    )
-    expr1 = re.compile(
-        r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+),"
-    )
-    expr2 = re.compile(r"workgroup = \[([0-9]+)(, ([0-9]+))+\]")
-    expr3 = re.compile(r"reduction = \[([0-9]+)(, ([0-9]+))+\]")
-    expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
-    expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
-    repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"
-    repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, compilation_info.translation_info.workgroup_size))}] subgroup_size = {compilation_info.translation_info.subgroup_size},'
-    repl2 = f"workgroup = {workgroup_sizes}"
-    repl3 = f"reduction = {reduction_sizes}"
-    repl4 = f"gpu_pipeline_options = {gpu_pipeline_options}"
-    repl5 = f'"amdgpu-waves-per-eu" = {waves_per_eu}'
-
-    new_mlir = ""
-    for line in template:
-        if "intrinsic =" in line:
-            line = re.sub(expr0, repl0, line)
-        if "LLVMGPUVectorDistribute " in line:
-            line = re.sub(expr1, repl1, line)
-        if "workgroup" in line:
-            line = re.sub(expr2, repl2, line)
-        if "reduction" in line:
-            line = re.sub(expr3, repl3, line)
-        if "gpu_pipeline_options =" in line:
-            line = re.sub(expr4, repl4, line)
-        if "amdgpu-waves-per-eu" in line:
-            line = re.sub(expr5, repl5, line)
-        new_mlir += line
-
-    return new_mlir


 class DispatchTuner(DispatchParser):
     # TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove this in favor of configuring using transform dialect.
     @abstractmethod
@@ -206,321 +147,6 @@ def get_td_spec(
         return build_td_spec(ir_module.context, conv_op, compilation_info, func_name)


-class MmtTuner(DispatchTuner, MmtParser):
-    def get_transform_function_mmt(
-        self,
-        problem_size: ProblemSize,
-        functionName: str,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> str:
-        return f"""
-    transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
-        %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
-        %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
-        %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
-        transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
-        transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
-        %config = transform.param.constant {compilation_info} -> !transform.any_param
-        transform.yield %matmul, %config : !transform.any_op, !transform.any_param
-    }}
-    """
-
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> MLIRTransformation:
-        M, N, K = problem_size.MNK
-        modified = indent(
-            self.get_transform_function_mmt(
-                problem_size, f"match_mmt_{M}x{N}x{K}", compilation_info
-            ),
-            "// ",
-        )
-        modified += apply_configuration(
-            template,
-            compilation_info,
-        )
-        embeddable = indent(
-            self.get_transform_function_mmt(
-                problem_size, f"match_op", compilation_info
-            ),
-            " ",
-        )
-        return MLIRTransformation(template, modified, embeddable)
-
-    def get_td_spec(
-        self,
-        ir_module: ir.Module,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> ir.Module:
-        raise NotImplementedError
-
-
-class ConvTuner(DispatchTuner, ConvParser):
-    def get_transform_function_conv(
-        self,
-        problem_size: ProblemSize,
-        functionName: str,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> str:
-        dynamic_batch_input_ty = problem_size.lhs_type
-        dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy()
-        dynamic_batch_input_ty.shape[0] = -1
-
-        dynamic_batch_output_ty = problem_size.res_type
-        dynamic_batch_output_ty.shape = dynamic_batch_output_ty.shape.copy()
-        dynamic_batch_output_ty.shape[0] - 1
-
-        input = f"tensor<{dynamic_batch_input_ty}>"
-        filter = f"tensor<{problem_size.rhs_type}>"
-        output = f"tensor<{dynamic_batch_output_ty}>"
-
-        return f"""
-    transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}})
-        -> (!transform.any_op, !transform.any_param) {{
-        %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{
-        ^bb0(%lhs: {input}, %rhs: {filter}, %out: {output}):
-            %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}}
-                ins(%lhs, %rhs : {input}, {filter})
-                outs(%out : {output}) -> {output}
-        }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
-        %config = transform.param.constant {compilation_info} -> !transform.any_param
-        transform.yield %conv, %config : !transform.any_op, !transform.any_param
-    }}
-    """
-
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> MLIRTransformation:
-        conv_dims = ConvDimInfo.from_problem_size(problem_size)
-        modified = indent(
-            self.get_transform_function_conv(
-                problem_size,
-                f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}",
-                compilation_info,
-            ),
-            "// ",
-        )
-        modified += apply_configuration(
-            template,
-            compilation_info,
-        )
-        embeddable = indent(
-            self.get_transform_function_conv(
-                problem_size, f"match_op", compilation_info
-            ),
-            " ",
-        )
-        return MLIRTransformation(template, modified, embeddable)
-
-    def get_td_spec(
-        self,
-        ir_module: ir.Module,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> ir.Module:
-        raise NotImplementedError
-
-
-class ContractionTuner(DispatchTuner, ContractionParser):
-    def get_transform_function_broadcast_rhs_mmt(
-        self,
-        problem_size: ProblemSize,
-        functionName: str,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> str:
-        lhs_dynamic_batch = problem_size.lhs_type
-        lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy()
-        lhs_dynamic_batch.shape[0] = -1
-
-        return f"""
-transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
-%mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
-%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
-%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
-transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
-transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
-%config = transform.param.constant {compilation_info} -> !transform.any_param
-transform.yield %generic, %config : !transform.any_op, !transform.any_param
-}}
-"""
-
-    def apply_params_broadcast_rhs_mmt(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> MLIRTransformation:
-        M, N, K = problem_size.MNK
-        modified = indent(
-            self.get_transform_function_broadcast_rhs_mmt(
-                problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", compilation_info
-            ),
-            "// ",
-        )
-        modified += apply_configuration(
-            template,
-            compilation_info,
-        )
-
-        embeddable = indent(
-            self.get_transform_function_broadcast_rhs_mmt(
-                problem_size, f"match_op", compilation_info
-            ),
-            " ",
-        )
-        return MLIRTransformation(template, modified, embeddable)
-
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> MLIRTransformation:
-        if self.is_broadcast_rhs_mmt(template):
-            return self.apply_params_broadcast_rhs_mmt(
-                problem_size, template, compilation_info
-            )
-
-        # TODO: Generate transform function.
-        return MLIRTransformation(
-            template,
-            apply_configuration(
-                template,
-                compilation_info,
-            ),
-            "",
-        )
-
-    def get_td_spec(
-        self,
-        ir_module: ir.Module,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> ir.Module:
-        raise NotImplementedError
-
-
-class BatchMmtTuner(DispatchTuner, BatchMmtParser):
-    def get_transform_function_batch_mmt(
-        self,
-        problem_size: ProblemSize,
-        functionName: str,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> str:
-        return f"""
-transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
-%mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
-%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
-%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
-transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
-transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
-%config = transform.param.constant {compilation_info} -> !transform.any_param
-transform.yield %generic, %config : !transform.any_op, !transform.any_param
-}}
-"""
-
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> MLIRTransformation:
-        M, N, K = problem_size.MNK
-        B = problem_size.matmul_size.B
-        modified = indent(
-            self.get_transform_function_batch_mmt(
-                problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", compilation_info
-            ),
-            "// ",
-        )
-        modified += apply_configuration(
-            template,
-            compilation_info,
-        )
-
-        embeddable = indent(
-            self.get_transform_function_batch_mmt(
-                problem_size, f"match_op", compilation_info
-            ),
-            " ",
-        )
-        return MLIRTransformation(template, modified, embeddable)
-
-    def get_td_spec(
-        self,
-        ir_module: ir.Module,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> ir.Module:
-        raise NotImplementedError
-
-
-class BatchMatmulTuner(DispatchTuner, BatchMatmulParser):
-    def get_transform_function_batch_matmul(
-        self,
-        problem_size: ProblemSize,
-        tile_dims: str,
-        functionName: str,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> str:
-        input0 = f"tensor<{problem_size.lhs_type}>"
-        input1 = f"tensor<{problem_size.rhs_type}>"
-        output = f"tensor<{problem_size.res_type}>"
-
-        return f"""
-    transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}})
-        -> (!transform.any_op, !transform.any_param) {{
-        %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{
-        ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}):
-            %13 = linalg.batch_matmul
-                ins(%lhs, %rhs : {input0}, {input1})
-                outs(%out : {output}) -> {output}
-        }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
-        %config = transform.param.constant {compilation_info} -> !transform.any_param
-        transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
-    }}
-    """
-
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> MLIRTransformation:
-        M, N, K = problem_size.MNK
-        modified = indent(
-            self.get_transform_function_batch_matmul(
-                problem_size,
-                self.tile_dims,
-                f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}",
-                compilation_info,
-            ),
-            "// ",
-        )
-        modified += apply_configuration(
-            template,
-            compilation_info,
-        )
-
-        embeddable = indent(
-            self.get_transform_function_batch_matmul(
-                problem_size, self.tile_dims, f"match_op", compilation_info
-            ),
-            " ",
-        )
-        return MLIRTransformation(template, modified, embeddable)
-
-    def get_td_spec(
-        self,
-        ir_module: ir.Module,
-        compilation_info: iree_codegen.CompilationInfoAttr,
-    ) -> ir.Module:
-        raise NotImplementedError
-
-
 @dataclass
 class OpWalkResult:
     was_interrupted: bool = False
@@ -563,82 +189,6 @@ def get_default_output_dir() -> str:
     return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")

-
-# TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove in favor of using tune_with_td.
-def tune(
-    input: str,  # Path to the mlir file to be tuned
-    output: str = "",  # Path to the output directory, auto creates one if not given
-    limit: int = 4096,  # Max candidates to be generated
-    num_subgroups: int = 4,  # GPU spec, used to determine candidate generation constraints
-    lhs_dims: str = "mk",  # Dimensions for the left-hand side operand in matrix operations
-    rhs_dims: str = "nk",  # Dimensions for the right-hand side operand in matrix operations
-    tile_dims: str = "mnk",  # Dimensions for the tile size
-):
-    input_file = str(input)
-
-    if not output:
-        output = get_default_output_dir()
-
-    # Create the directory if it does not exist
-    makedirs(str(output), exist_ok=True)
-
-    tune_logger.debug(f"Output directory {output}")
-    tune_logger.debug(f"Processing {input_file}")
-    mlir_template = read_input_mlir(input_file)
-    mlir_text = "".join(mlir_template)
-
-    with ir.Context() as ctx:
-        tuner_context = TunerContext(ctx, tune_logger)
-        mlir_module = parse_mlir(mlir_text, tuner_context)
-        # Save the input file as the first candidate.
-        with open(path.join(output, f"0.mlir"), "w") as f:
-            f.write(mlir_text)
-
-        dispatch_tuner_registry = DispatchTunerRegistry()
-        dispatch_tuner_registry.register(
-            [
-                MmtTuner(),
-                ConvTuner(),
-                ContractionTuner(lhs_dims, rhs_dims, tile_dims),
-                BatchMmtTuner(),
-                BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
-            ]
-        )
-
-        walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)
-
-        variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module)
-        assert len(variant_op_list) == 1, "Expect one executable variant op"
-        variant_op = variant_op_list[0]
-        # Get the MMA intrinsic instructions supported by the target.
-        mma_list = iree_codegen.query_mma_intrinsics(variant_op)
-
-        dispatch_tuner = walk_result.dispatch_tuner
-        assert dispatch_tuner, "No suitable dispatch tuner found"
-        problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
-        tune_logger.debug(str(problem_size))
-        configs = []
-        for i, config in enumerate(
-            generate_solutions(tuner_context, problem_size, num_subgroups, mma_list)
-        ):
-            if i >= limit:
-                break
-            tune_logger.info(f"Solution #{i+1}: {config}")
-            configs.append(config)
-            tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config)
-
-            with open(path.join(output, f"{i+1}.mlir"), "w") as f:
-                f.write(tf_mlir.modified)
-            with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
-                f.write(tf_mlir.embeddable)
-
-        # TODO: Fix pickling for ir types.
-        # with open(path.join(output, "configs.pkl"), "wb") as file:
-        #     pickle.dump(configs, file)
-
-        tune_logger.info(f"Generated {len(configs)} candidates")
-        tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
-

 def generate_configs_and_td_specs(
     input_module: ir.Module,  # Path to the mlir file to be tuned
     tuner_context: TunerContext,
@@ -684,6 +234,98 @@ def generate_configs_and_td_specs(
     return config_specs


+@dataclass
+class RunPack:
+    command: list[str]
+    check: bool = True
+    timeout_seconds: Optional[int] = None
+
+
+@dataclass
+class RunResult:
+    process_res: Optional[subprocess.CompletedProcess]
+    is_timeout: bool
+
+
+def run_command(run_pack: RunPack) -> RunResult:
+    command = run_pack.command
+    check = run_pack.check
+    timeout_seconds = run_pack.timeout_seconds
+
+    result = None
+    is_timeout = False
+    try:
+        # Convert the command list to a command string for logging
+        command_str = " ".join(command)
+        logging.debug(f"Run: {command_str}")
+
+        # Add timeout to subprocess.run call
+        result = subprocess.run(
+            command,
+            check=check,
+            capture_output=True,
+            text=True,
+            timeout=timeout_seconds,
+        )
+
+        if result.stdout:
+            logging.debug(f"stdout: {result.stdout}")
+        if result.stderr:
+            logging.debug(f"stderr: {result.stderr}")
+    except subprocess.TimeoutExpired as e:
+        logging.warning(
+            f"Command '{command_str}' timed out after {timeout_seconds} seconds."
+        )
+        is_timeout = True
+    except subprocess.CalledProcessError as e:
+        print(e.output)
+        logging.error(
+            f"Command '{command_str}' returned non-zero exit status {e.returncode}."
+        )
+        logging.error(f"Command '{command_str}' failed with error: {e.stderr}")
+        if check:
+            raise
+    except KeyboardInterrupt:
+        print("Ctrl+C detected, terminating child processes...")
+
+    return RunResult(result, is_timeout)
+
+
+# The `strip_root_op_attr` and `strip_compilation_info` functions are used for
+# getting consistent inputs to the compilation step in tuning. Inputs may come
+# in with lowering configs, translation info, and root_op attrs when the input
+# is a benchmark, but not when the input is a source MLIR file. Stripping the
+# info makes the inputs to compilation consistent, and allows for overwriting
+# the compilation info with generated TD specs during codegen.
+def strip_root_op_attr(module: ir.Module):
+    root_ops: list[ir.Operation] = get_ops_from_module(module, is_root_op)
+    for root_op in root_ops:
+        assert (
+            ROOT_OP_ATTR_NAME in root_op.opview.attributes
+        ), f"expected root op to have '{ROOT_OP_ATTR_NAME}' attr"
+        del root_op.opview.attributes[ROOT_OP_ATTR_NAME]
+
+
+# See the above comment for `strip_root_op_attr`.
+def strip_compilation_info(input_path: Path) -> str:
+    # Strip compilation info from the source and save the stripped IR
+    strip_command = [
+        f"iree-opt",
+        f"{input_path}",
+        f"--iree-codegen-strip-compilation-info",
+    ]
+    result = run_command(
+        RunPack(
+            command=strip_command,
+            check=True,
+        )
+    )
+    assert (
+        result.process_res is not None
+    ), "expected result from stripping compilation info"
+    return result.process_res.stdout
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("input", help="Input mlir file", type=str)
@@ -703,15 +345,6 @@ def main():
         type=int,
         default=-1,
     )
-    parser.add_argument(
-        "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk"
-    )
-    parser.add_argument(
-        "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk"
-    )
-    parser.add_argument(
-        "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk"
-    )
     parser.add_argument(
         "--verbose", "-v", action="store_true", help="Enable verbose output to stdout"
     )
@@ -727,20 +360,22 @@ def main():
     console_handler.setFormatter(formatter)
     tune_logger.addHandler(console_handler)

-    # # Optionally, add a file handler to log to a file
-    # file_handler = logging.FileHandler("tune.log")
-    # file_handler.setFormatter(formatter)
-    # tune_logger.addHandler(file_handler)
-
-    tune(
-        args.input,
-        args.output,
-        args.limit,
-        args.num_subgroups,
-        args.lhs_dims,
-        args.rhs_dims,
-        args.tile_dims,
-    )
+    with ir.Context() as ctx:
+        tuner_ctx = TunerContext(ctx, tune_logger)
+        mlir_text = strip_compilation_info(args.input)
+        mlir_module = parse_mlir(mlir_text, tuner_ctx)
+        specs = generate_configs_and_td_specs(
+            mlir_module,
+            tuner_ctx,
+            args.limit,
+            args.num_subgroups,
+        )
+        for candidate_num, spec in enumerate(specs):
+            spec_dir = Path(args.output)
+            spec_path = spec_dir / f"{candidate_num}_spec.mlir"
+            spec_dir.mkdir(parents=True, exist_ok=True)
+            with open(spec_path, "w") as f:
+                f.write(str(spec))


 if __name__ == "__main__":
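The `RunPack`/`run_command` helpers added above are self-contained and are what `strip_compilation_info` builds on. A small usage sketch; the `benchmark.mlir` path is hypothetical, and the `iree-opt` flag is the one used by `strip_compilation_info`:

```python
from pathlib import Path

from tuner.candidate_gen import RunPack, RunResult, run_command, strip_compilation_info

# Run a bounded external command; check=True re-raises on a non-zero exit.
result: RunResult = run_command(
    RunPack(
        command=["iree-opt", "benchmark.mlir", "--iree-codegen-strip-compilation-info"],
        check=True,
        timeout_seconds=30,
    )
)
if not result.is_timeout and result.process_res is not None:
    stripped_mlir = result.process_res.stdout

# Equivalent via the wrapper defined above (check=True, no timeout):
stripped_mlir = strip_compilation_info(Path("benchmark.mlir"))
```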
- config_dict = common.get_translation_info_config(pipeline_options, 8) - translation_info = iree_codegen.TranslationInfoAttr.get( - pipeline_attr, None, [16, 16, 1], 16, config_dict - ) - compilation_info = iree_codegen.CompilationInfoAttr.get( - lowering_config, translation_info - ) - - problem_size = common.ProblemSize( - common.MatmulSize(M, N, K), - common.ShapedType([M, K], tuner_ctx.type.f16), - common.ShapedType([N, K], tuner_ctx.type.f16), - common.ShapedType([M, N], tuner_ctx.type.f32), - common.DispatchKind.mmt, - ) - tf_mlir = candidate_gen.MmtTuner().apply_params( - problem_size, mlir_template, compilation_info - ) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - modified = remove_comments(modified) - assert embeddable - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 16, subgroup_n_count = 16" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" - in modified - ) - assert "workgroup = [8, 8, 0]" in modified - assert "reduction = [0, 0, 8]" in modified - assert ( - "gpu_pipeline_options = #iree_gpu.pipeline_options" - in modified - ) - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified - - -def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: - mlir_template = [ - ", subgroup_m_count = 16, subgroup_n_count = 16>", - "", - 'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', - ] - - n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 16 - - mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 - mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config = common.get_lowering_config( - tuner_ctx=tuner_ctx, - mma_kind=mma_attr, - workgroup=[n, oh, ow, oc, fh, fw, 0], - reduction=[0, 0, 0, 0, 0, 0, ic], - subgroup_m_count=1, - subgroup_n_count=4, - ) - pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( - iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute - ) - pipeline_options = iree_gpu.PipelineOptionsAttr.get( - reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get( - iree_gpu.ReorderWorkgroupsStrategy.Transpose - ) - ) - config_dict = common.get_translation_info_config(pipeline_options, 2) - translation_info = iree_codegen.TranslationInfoAttr.get( - pipeline_attr, None, [256, 1, 1], 64, config_dict - ) - compilation_info = iree_codegen.CompilationInfoAttr.get( - lowering_config, translation_info - ) - - problem_size = common.ProblemSize( - common.MatmulSize(oh * ow, oc, fh * fw * ic), - common.ShapedType([n, oh + 2, ow + 2, oc], tuner_ctx.type.f16), - common.ShapedType([fh, fw, ic, oc], tuner_ctx.type.f16), - common.ShapedType([n, oh, ow, oc], tuner_ctx.type.f32), - common.DispatchKind.conv, - ) - tf_mlir = candidate_gen.ConvTuner().apply_params( - problem_size, mlir_template, compilation_info - ) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - modified = remove_comments(modified) - assert embeddable - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" - in modified - ) - assert "workgroup = [2, 64, 64, 640, 3, 3, 0]" in modified - assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified - assert ( - "gpu_pipeline_options = #iree_gpu.pipeline_options>" - in modified - ) - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified - - -def 
test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: - mlir_template = [ - ", subgroup_m_count = 2, subgroup_n_count = 2>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - tile_dims = "*mnk" - problem_size = common.ProblemSize( - common.MatmulSize(2048, 3840, 1280), - common.ShapedType([2, 1024, 1280], tuner_ctx.type.f16), - common.ShapedType([3, 20, 64, 1280], tuner_ctx.type.f16), - common.ShapedType([3, 2, 20, 1024, 64], tuner_ctx.type.f32), - common.DispatchKind.contraction, - ) - - mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 - mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config = common.get_lowering_config( - tuner_ctx=tuner_ctx, - mma_kind=mma_attr, - workgroup=[1, 480, 384, 0], - reduction=[0, 0, 0, 32], - subgroup_m_count=1, - subgroup_n_count=4, - ) - pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( - iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute - ) - pipeline_options = iree_gpu.PipelineOptionsAttr.get() - config_dict = common.get_translation_info_config(pipeline_options, 2) - translation_info = iree_codegen.TranslationInfoAttr.get( - pipeline_attr, None, [256, 1, 1], 64, config_dict - ) - compilation_info = iree_codegen.CompilationInfoAttr.get( - lowering_config, translation_info - ) - - tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( - problem_size, mlir_template, compilation_info - ) - - new_mlir = tf_mlir.modified - - assert new_mlir - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" - in new_mlir - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" - in new_mlir - ) - assert "workgroup = [1, 480, 384, 0]" in new_mlir - assert "reduction = [0, 0, 0, 32]" in new_mlir - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir - - -def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - tile_dims = "bmnk" - problem_size = common.ProblemSize( - common.MatmulSize(968, 320, 640, 64), - common.ShapedType([64, 968, 640], tuner_ctx.type.f16), - common.ShapedType([64, 640, 320], tuner_ctx.type.f16), - common.ShapedType([64, 968, 320], tuner_ctx.type.f32), - common.DispatchKind.batch_matmul, - ) - - mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 - mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config = common.get_lowering_config( - tuner_ctx=tuner_ctx, - mma_kind=mma_attr, - workgroup=[1, 416, 320, 0], - reduction=[0, 0, 0, 128], - subgroup_m_count=2, - subgroup_n_count=2, - ) - pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( - iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute - ) - pipeline_options = iree_gpu.PipelineOptionsAttr.get() - config_dict = common.get_translation_info_config(pipeline_options, 2) - translation_info = iree_codegen.TranslationInfoAttr.get( - pipeline_attr, None, [128, 2, 1], 64, config_dict - ) - compilation_info = iree_codegen.CompilationInfoAttr.get( - lowering_config, translation_info - ) - - tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( - problem_size, mlir_template, compilation_info - ) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - modified = remove_comments(modified) - - assert embeddable - assert ( - "intrinsic = #iree_gpu.mma_layout, 
subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "workgroup = [1, 416, 320, 0]" in modified - assert "reduction = [0, 0, 0, 128]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified - - -def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - problem_size = common.ProblemSize( - common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], tuner_ctx.type.f16), - common.ShapedType([2, 640, 640], tuner_ctx.type.f16), - common.ShapedType([2, 4096, 640], tuner_ctx.type.f32), - common.DispatchKind.batch_mmt, - ) - - mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 - mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config = common.get_lowering_config( - tuner_ctx=tuner_ctx, - mma_kind=mma_attr, - workgroup=[1, 128, 64, 0], - reduction=[0, 0, 0, 128], - subgroup_m_count=2, - subgroup_n_count=2, - ) - pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( - iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute - ) - pipeline_options = iree_gpu.PipelineOptionsAttr.get() - config_dict = common.get_translation_info_config(pipeline_options, 2) - translation_info = iree_codegen.TranslationInfoAttr.get( - pipeline_attr, None, [128, 2, 1], 64, config_dict - ) - compilation_info = iree_codegen.CompilationInfoAttr.get( - lowering_config, translation_info - ) - - tf_mlir = candidate_gen.BatchMmtTuner().apply_params( - problem_size, mlir_template, compilation_info - ) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert embeddable - assert modified - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "workgroup = [1, 128, 64, 0]" in modified - assert "reduction = [0, 0, 0, 128]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified - - -def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - problem_size = common.ProblemSize( - common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), - common.ShapedType([2, 640, 640], tuner_ctx.type.i8), - common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), - common.DispatchKind.batch_mmt, - ) - - mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 - mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config = common.get_lowering_config( - tuner_ctx=tuner_ctx, - mma_kind=mma_attr, - workgroup=[1, 128, 64, 0], - reduction=[0, 0, 0, 128], - subgroup_m_count=2, - subgroup_n_count=2, - ) - pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( - iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute - ) - pipeline_options = iree_gpu.PipelineOptionsAttr.get() - config_dict = common.get_translation_info_config(pipeline_options, 4) - translation_info = iree_codegen.TranslationInfoAttr.get( - pipeline_attr, None, [128, 2, 1], 64, config_dict - ) - compilation_info = iree_codegen.CompilationInfoAttr.get( - lowering_config, translation_info - ) - - tf_mlir = 
candidate_gen.BatchMmtTuner().apply_params( - problem_size, mlir_template, compilation_info - ) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified - modified = remove_comments(modified) - - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "workgroup = [1, 128, 64, 0]" in modified - assert "reduction = [0, 0, 0, 128]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified - - assert embeddable - assert "transform.named_sequence @match_op(" in embeddable - assert ( - "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "%config = transform.param.constant #iree_codegen.compilation_info<" - in embeddable - ) - assert "workgroup = [1, 128, 64, 0]" in embeddable - assert "reduction = [0, 0, 0, 128]" in embeddable - assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable - assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable - - -def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - problem_size = common.ProblemSize( - common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), - common.ShapedType([640, 640], tuner_ctx.type.i8), - common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), - common.DispatchKind.broadcast_rhs_mmt, - ) - - mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 - mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config = common.get_lowering_config( - tuner_ctx=tuner_ctx, - mma_kind=mma_attr, - workgroup=[1, 128, 64, 0], - reduction=[0, 0, 0, 128], - subgroup_m_count=2, - subgroup_n_count=2, - ) - pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( - iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute - ) - pipeline_options = iree_gpu.PipelineOptionsAttr.get() - config_dict = common.get_translation_info_config(pipeline_options, 4) - translation_info = iree_codegen.TranslationInfoAttr.get( - pipeline_attr, None, [128, 2, 1], 64, config_dict - ) - compilation_info = iree_codegen.CompilationInfoAttr.get( - lowering_config, translation_info - ) - - tf_mlir = candidate_gen.ContractionTuner( - "mk", "nk", "mnk" - ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, compilation_info) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - assert ( - "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640(" - in modified - ) - modified = remove_comments(modified) - - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "workgroup = [1, 128, 64, 0]" in modified - assert "reduction = [0, 0, 0, 128]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified 
- - assert embeddable - assert "transform.named_sequence @match_op(" in embeddable - assert ( - "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)" - in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %lhs = tensor<?x4096x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "%config = transform.param.constant #iree_codegen.compilation_info<" - in embeddable - ) - assert "workgroup = [1, 128, 64, 0]" in embeddable - assert "reduction = [0, 0, 0, 128]" in embeddable - assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable - assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable - - -def test_detect_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: - mlir_lines = [ - r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - ] - assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( - mlir_lines - ) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 78e3a8e9d..54051df47 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -45,12 +45,8 @@ def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): class DispatchKind(Enum): - conv = 1 - mmt = 2 - contraction = 3 - batch_mmt = 4 - batch_matmul = 5 - broadcast_rhs_mmt = 6 + conv = 0 + contraction = 1 @dataclass @@ -108,11 +104,10 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: a_type, b_type, c_type = mma_attr.abc_element_types if not isinstance(problem_size.res_type.element_type, type(c_type)): return False - if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if not isinstance( - problem_size.lhs_type.element_type, type(a_type) - ) or not isinstance(problem_size.rhs_type.element_type, type(b_type)): - return False + if not isinstance( + problem_size.lhs_type.element_type, type(a_type) + ) or not isinstance(problem_size.rhs_type.element_type, type(b_type)): + return False return True return list(filter(is_comptible, mma_intrinsics)) diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 6157bb355..b23360ccc 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -123,11 +123,13 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([2048, 1280], tuner_ctx.type.f16), common.ShapedType([1280, 1280], tuner_ctx.type.f16), common.ShapedType([2048, 1280], tuner_ctx.type.f32), - common.DispatchKind.mmt, + common.DispatchKind.contraction, ), [ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ], ) == [ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, @@ -140,9 +142,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([2048, 1280], tuner_ctx.type.i8),
common.ShapedType([1280, 1280], tuner_ctx.type.i8), common.ShapedType([2048, 1280], tuner_ctx.type.i32), - common.DispatchKind.mmt, + common.DispatchKind.contraction, ), [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ], @@ -151,38 +155,6 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ] - assert common.get_compatible_mfma_intrinsics( - common.ProblemSize( - common.MatmulSize(968, 320, 640, 64), - common.ShapedType([64, 968, 640], tuner_ctx.type.f32), - common.ShapedType([64, 640, 320], tuner_ctx.type.f32), - common.ShapedType([64, 968, 320], tuner_ctx.type.f32), - common.DispatchKind.batch_matmul, - ), - [ - iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, - iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, - ], - ) == [ - iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, - iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, - ] - - assert common.get_compatible_mfma_intrinsics( - common.ProblemSize( - common.MatmulSize(968, 320, 640, 64), - common.ShapedType([64, 968, 640], tuner_ctx.type.f32), - common.ShapedType([64, 640, 320], tuner_ctx.type.f32), - common.ShapedType([64, 968, 320], tuner_ctx.type.f32), - common.DispatchKind.batch_matmul, - ), - [ - iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, - ], - ) == [ - iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, - ] - assert ( common.get_compatible_mfma_intrinsics( common.ProblemSize( @@ -190,9 +162,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([64, 968, 640], tuner_ctx.type.f32), common.ShapedType([64, 640, 320], tuner_ctx.type.f32), common.ShapedType([64, 968, 320], tuner_ctx.type.f32), - common.DispatchKind.batch_matmul, + common.DispatchKind.contraction, ), [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ], diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 914c04bbf..f6de5179d 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -232,15 +232,16 @@ def generate_solutions( problem_size.lhs_type.element_type, problem_size.rhs_type.element_type, ) + workgroup_tiles = [lookup(m), lookup(n), 0] + reduction_tiles = [0, 0, lookup(k)] + if problem_size.dispatch_kind == DispatchKind.conv: + workgroup_tiles = [1, 1, lookup(m), lookup(n), 0, 0, 0] + reduction_tiles = [0, 0, 0, 0, 1, 1, lookup(k)] lowering_config = get_lowering_config( tuner_ctx=tuner_ctx, mma_kind=mma_attr, - workgroup=[lookup(m), lookup(n), 0], - reduction=[ - 0, - 0, - lookup(k), - ], + workgroup=workgroup_tiles, + reduction=reduction_tiles, subgroup_m_count=lookup(sg_m_cnt), subgroup_n_count=lookup(sg_n_cnt), ) diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 842ea8509..5c82f555f 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -36,7 +36,7 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: rhs_type = common.ShapedType([3840, 1280], tuner_ctx.type.f16) res_type = common.ShapedType([2048, 3840], tuner_ctx.type.f32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction ) 
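The `dispatch_constraints.py` hunk above picks the tile lists by dispatch kind. Read in isolation (with `lookup(...)` results replaced by plain integers), the selection logic amounts to the sketch below; the 7-loop layout (N, OH, OW, OC, FH, FW, IC) is an assumption based on `conv_2d_nhwc_hwcf`:

```python
def make_tile_lists(is_conv: bool, m: int, n: int, k: int) -> tuple[list[int], list[int]]:
    # Contraction default: tile M and N across the workgroup, K in the
    # reduction loop.
    workgroup = [m, n, 0]
    reduction = [0, 0, k]
    if is_conv:
        # Assumed conv_2d_nhwc_hwcf loop order (N, OH, OW, OC, FH, FW, IC):
        # batch/OH tiled by 1, OW/OC carry the M/N tiles, filter dims tiled
        # by 1, and the input-channel loop carries the K tile.
        workgroup = [1, 1, m, n, 0, 0, 0]
        reduction = [0, 0, 0, 0, 1, 1, k]
    return workgroup, reduction

assert make_tile_lists(False, 128, 64, 32) == ([128, 64, 0], [0, 0, 32])
assert make_tile_lists(True, 128, 64, 32) == ([1, 1, 128, 64, 0, 0, 0], [0, 0, 0, 0, 1, 1, 32])
```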
configs = dispatch_constraints.generate_solutions( tuner_ctx, @@ -59,7 +59,7 @@ def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext) rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction ) assert ( dispatch_constraints.calculate_shared_memory_usage_in_bytes( @@ -70,7 +70,7 @@ def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext) lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i8) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction ) assert ( dispatch_constraints.calculate_shared_memory_usage_in_bytes( @@ -81,7 +81,7 @@ def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext) rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction ) assert ( dispatch_constraints.calculate_shared_memory_usage_in_bytes( @@ -97,7 +97,7 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction ) # Define input parameters as z3 Ints m, n, k = ( @@ -149,7 +149,7 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction ) m, n, k = ( z3.Int("m"), diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index 735d6145c..502968ea8 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -7,65 +7,12 @@ # Given an input dispatch, this code modifies the hyperparameters # in the code and runs it. 
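With `DispatchKind` trimmed to `conv` and `contraction`, every matmul-like problem in these tests is now constructed the same way. An illustrative construction, assuming `common` and a `tuner_ctx` fixture in scope as in the tests above (M/N/K inferred from the shapes):

```python
# Illustrative only: a transposed-B matmul expressed as a plain contraction.
problem_size = common.ProblemSize(
    common.MatmulSize(2048, 3840, 1280),                  # M, N, K
    common.ShapedType([2048, 1280], tuner_ctx.type.f16),  # LHS
    common.ShapedType([3840, 1280], tuner_ctx.type.f16),  # RHS, B transposed
    common.ShapedType([2048, 3840], tuner_ctx.type.f32),  # result
    common.DispatchKind.contraction,
)
```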
-import math -import re from abc import ABCMeta, abstractmethod from .op_matchers import * from .common import * -def parse_tensor_type(tensor_type: str) -> ShapedType: - shaped_ty = ir.RankedTensorType(ir.Type.parse(tensor_type)) - assert shaped_ty - return ShapedType(shaped_ty.shape, shaped_ty.element_type) - - -def get_contract_workgroup_sizes( - configuration: iree_codegen.CompilationInfoAttr, tile_dims: str -) -> list[int]: - m, n, _k = configuration.lowering_config.workgroup_tile_sizes - - workgroup_size = [1] * len(tile_dims) - for idx, dim in enumerate(tile_dims): - if dim == "m": - workgroup_size[idx] = m - if dim == "n": - workgroup_size[idx] = n - if dim == "k": - workgroup_size[idx] = 0 - - return workgroup_size - - -def get_contract_reduction_sizes( - configuration: iree_codegen.CompilationInfoAttr, tile_dims: str -) -> list[int]: - _m, _n, k = configuration.lowering_config.reduction_tile_sizes - reduction_size = [0] * len(tile_dims) - for idx, dim in enumerate(tile_dims): - if dim == "k": - reduction_size[idx] = k - - return reduction_size - - -class MlirRegex(Enum): - ssa_value = r"%[a-zA-Z0-9-_]+" - tensor_type = r"tensor<([^>]+)>" - - def __str__(self) -> str: - return self.value - - @staticmethod - def dps_ins_two_args() -> str: - return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P<LHS>{MlirRegex.tensor_type}), (?P<RHS>{MlirRegex.tensor_type})\)" - - @staticmethod - def dps_outs_one_arg() -> str: - return rf"outs\({MlirRegex.ssa_value} : (?P<RES>{MlirRegex.tensor_type})\)" - - def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module: mlir_module = None try: @@ -179,359 +126,3 @@ def get_shapes(self, template: list[str]) -> ProblemSize: res_type=ShapedType(res_type.shape, res_type.element_type), dispatch_kind=DispatchKind.conv, ) - - -class MmtParser(DispatchParser): - def supports(self, op_name: str) -> bool: - return "matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - mmt_re = None - dps = None - for line in template: - if "linalg.generic" not in line: - continue - if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: - continue - # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) - mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(mmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 2 - lhs_M, lhs_K = lhs_shaped_type.shape - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - rhs_N, rhs_K = rhs_shaped_type.shape - - assert lhs_shaped_type.element_type == rhs_shaped_type.element_type - assert lhs_K == rhs_K - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 2 - res_M, res_N = res_shaped_type.shape - - assert lhs_M == res_M - assert rhs_N == res_N - - matmul_size = MatmulSize( - lhs_shaped_type.shape[0], - rhs_shaped_type.shape[0], - lhs_shaped_type.shape[1], - ) - return ProblemSize( - matmul_size, - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.mmt, - ) - assert mmt_re - assert False, f"'{mmt_re}' not found in given context" - - -class ConvParser(DispatchParser): - def supports(self, op_name: str) -> bool: - return "conv_2d_nhwc_hwcf" in op_name - - def get_shapes(self,
template: list[str]) -> ProblemSize: - for line in template: - if "linalg.conv_2d_nhwc_hwcf" not in line: - continue - - # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) - conv_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(conv_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 4 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 4 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 4 - - dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) - return ProblemSize( - MatmulSize( - M=dim_info.oh * dim_info.ow, - N=dim_info.oc, - K=dim_info.fh * dim_info.fw * dim_info.ic, - B=dim_info.n, - ), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.conv, - ) - - assert False, "Shape not found" - - -class ContractionParser(DispatchParser): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "matmul_like" in op_name - - def is_broadcast_rhs_mmt_op(self, line: str) -> bool: - if "linalg.generic" not in line: - return False - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - return False - if ( - r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" - not in line - ): - return False - return True - - def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: - return any(self.is_broadcast_rhs_mmt_op(line) for line in template) - - def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: - for line in template: - if not self.is_broadcast_rhs_mmt_op(line): - continue - - # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.broadcast_rhs_mmt, - ) - - assert False, "Shape not found" - - def get_shapes(self, template: list[str]) -> ProblemSize: - if self.is_broadcast_rhs_mmt(template): - return self.get_shapes_broadcast_rhs_mmt(template) - - for line in template: - if "linalg.generic" not in line: - continue - if "lowering_config =" not in line: - continue - if '"reduction"' not in line: - continue - - # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) - cont_re = ( - 
rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() >= 2 - - M = math.prod( - val if dim == "m" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - N = math.prod( - val if dim == "n" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - K0 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - K1 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.contraction, - ) - - assert False, "Shape not found" - - -class BatchMmtParser(DispatchParser): - def supports(self, op_name: str) -> bool: - return "batch_matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - continue - # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 3 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - B1, N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B1 - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.batch_mmt, - ) - - assert False, "Shape not found" - - -class BatchMatmulParser(DispatchParser): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "batch_matmul" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.batch_matmul" not in line: - continue - # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) - # outs(%12 : tensor<64x72x1280xf32>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert 
rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == lhs_shaped_type.rank() - - LHS = lhs_shaped_type.shape - RHS = rhs_shaped_type.shape - RES = res_shaped_type.shape - - B = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - B0 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) - ) - B1 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) - ) - M = math.prod( - val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - N = math.prod( - val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - K0 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - K1 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - assert B == B0 and B == B1 - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0, B), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.batch_matmul, - ) - - assert False, "Shape not found" diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 0b87be659..c35b17bed 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -32,15 +32,6 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]: yield common.TunerContext(ctx, logger) -def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: - assert dispatch_parser.parse_tensor_type("tensor<1x2x3xf32>") == common.ShapedType( - [1, 2, 3], tuner_ctx.type.f32 - ) - assert dispatch_parser.parse_tensor_type("tensor<123xi8>") == common.ShapedType( - [123], tuner_ctx.type.i8 - ) - - CONTRACTION_TEMPLATE = r""" builtin.module{{ func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{ @@ -207,151 +198,6 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: ] -def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: - mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 - mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config = common.get_lowering_config( - tuner_ctx=tuner_ctx, - mma_kind=mma_attr, - workgroup=[4, 8, 0], - reduction=[0, 0, 16], - subgroup_m_count=1, - subgroup_n_count=1, - ) - pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( - iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute - ) - pipeline_options = iree_gpu.PipelineOptionsAttr.get() - config_dict = common.get_translation_info_config(pipeline_options, 2) - translation_info = iree_codegen.TranslationInfoAttr.get( - pipeline_attr, None, [16, 16, 1], 32, config_dict - ) - compilation_info = iree_codegen.CompilationInfoAttr.get( - lowering_config, translation_info - ) - assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "mnk") == [ - 4, - 8, - 0, - ] - assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "mnk") == [ - 0, - 0, - 16, - ] - assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "nmk") == [ - 8, - 4, - 0, - ] - assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "nmk") == [ - 0, - 0, - 16, - ] - assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "knm") == [ - 0, - 8, - 4, - ] - assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "knm") == [ - 16, - 0, - 0, - ] - assert 
dispatch_parser.get_contract_workgroup_sizes(compilation_info, "kkk") == [ - 0, - 0, - 0, - ] - assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "kkk") == [ - 16, - 16, - 16, - ] - - -def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None: - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert dispatch_parser.MmtParser().get_shapes(template) == common.ProblemSize( - common.MatmulSize(2048, 1280, 1280), - common.ShapedType([2048, 1280], tuner_ctx.type.f16), - common.ShapedType([1280, 1280], tuner_ctx.type.f16), - common.ShapedType([2048, 1280], tuner_ctx.type.f32), - dispatch_parser.DispatchKind.mmt, - ) - - -def test_get_shapes_conv(tuner_ctx: common.TunerContext) -> None: - template = [ - r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", - ] - assert dispatch_parser.ConvParser().get_shapes(template) == common.ProblemSize( - common.MatmulSize(32, 256, 11520), - common.ShapedType([1, 3, 34, 1280], tuner_ctx.type.f16), - common.ShapedType([3, 3, 1280, 256], tuner_ctx.type.f16), - common.ShapedType([1, 1, 32, 256], tuner_ctx.type.f32), - dispatch_parser.DispatchKind.conv, - ) - - -def test_get_shapes_contract(tuner_ctx: common.TunerContext) -> None: - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert dispatch_parser.ContractionParser("mk", "nk", "mnk").get_shapes( - template - ) == common.ProblemSize( - common.MatmulSize(2048, 1280, 1280), - common.ShapedType([2048, 1280], tuner_ctx.type.f16), - common.ShapedType([1280, 1280], tuner_ctx.type.f16), - common.ShapedType([2048, 1280], tuner_ctx.type.f32), - dispatch_parser.DispatchKind.contraction, - ) - - -def test_get_shapes_batch_matmul(tuner_ctx: common.TunerContext) -> None: - template = [ - "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "%11 = 
linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", - ] - assert dispatch_parser.BatchMatmulParser("bmk", "bkn", "mnk").get_shapes( - template - ) == common.ProblemSize( - common.MatmulSize(32, 32, 1024, 1), - common.ShapedType([1, 32, 1024], tuner_ctx.type.f32), - common.ShapedType([1, 1024, 32], tuner_ctx.type.f32), - common.ShapedType([1, 32, 32], tuner_ctx.type.f32), - dispatch_parser.DispatchKind.batch_matmul, - ) - - -def test_get_shapes_batch_mmt(tuner_ctx: common.TunerContext) -> None: - template = [ - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", - ] - assert dispatch_parser.BatchMmtParser().get_shapes(template) == common.ProblemSize( - common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), - common.ShapedType([2, 640, 640], tuner_ctx.type.i8), - common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), - dispatch_parser.DispatchKind.batch_mmt, - ) - - def test_parse_mlir(tuner_ctx: common.TunerContext) -> None: mlir_str = r""" builtin.module { diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index ff7b78a11..4e2a97ec8 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -6,11 +6,9 @@ """ Provides fundamental functions for tuning: - - generate_candidates() - - compile_dispatches() - - benchmark_dispatches() - - compile_models() - - benchmark_models() + - generate_candidate_specs() + - compile() + - benchmark() Requires a wrapper Python script to import `libtuner`, use the `TuningClient` API, customize compilation and benchmarking commands, @@ -20,9 +18,9 @@ import math import signal +import subprocess import sys import shutil -import subprocess import logging import argparse from datetime import datetime @@ -32,13 +30,9 @@ import multiprocessing import queue from tqdm import tqdm -import re import hashlib from dataclasses import dataclass, field from typing import Type, Optional, Callable, Iterable, Any -import pickle -import random -import json from abc import ABC, abstractmethod import iree.runtime as ireert # type: ignore import iree.compiler as ireec # type: ignore @@ -66,78 +60,32 @@ DEVICE_ID_PLACEHOLDER = "!DEVICE_ID!" -# TODO(Max191): Remove most of the fields here after refactoring is complete, -# since many of them will be unused. 
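With that TODO resolved, the tracker below keeps only a candidate id and three paths. A hypothetical instantiation (the mirror dataclass is illustrative; file names follow the `get_candidate_spec_filename` / `get_candidate_vmfb_filename` helpers kept further down):

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

@dataclass
class CandidateTracker:  # mirror of the slimmed dataclass below, for illustration
    candidate_id: int
    mlir_path: Optional[Path] = None
    compiled_vmfb_path: Optional[Path] = None
    spec_path: Optional[Path] = None

tracker = CandidateTracker(candidate_id=7)
tracker.spec_path = Path("candidates/specs") / "7_spec.mlir"
tracker.compiled_vmfb_path = Path("candidates/compiled") / "7.vmfb"
```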
@dataclass class CandidateTracker: candidate_id: int mlir_path: Optional[Path] = None - dispatch_mlir_path: Optional[Path] = None - dispatch_config_path: Optional[Path] = None - configuration: Optional[candidate_gen.iree_codegen.CompilationInfoAttr] = None - compilation_successful: Optional[bool] = None compiled_vmfb_path: Optional[Path] = None - compiled_dispatch_path: Optional[Path] = None - compiled_dispatch_hash: Optional[str] = None - first_benchmark_time: Optional[float] = None - first_benchmark_device_id: Optional[str] = None spec_path: Optional[Path] = None - compiled_model_path: Optional[Path] = None - compiled_model_hash: Optional[str] = None - model_benchmark_time: Optional[float] = None - model_benchmark_device_id: Optional[str] = None - baseline_benchmark_time: Optional[float] = None - calibrated_benchmark_diff: Optional[float] = None @dataclass() class PathConfig: - # Preset constants - global_config_prolog_mlir: Path = Path("config_prolog.mlir") - global_config_epilog_mlir: Path = Path("config_epilog.mlir") - model_baseline_vmfb: Path = Path("baseline.vmfb") - # Dynamic paths base_dir: Path = field(init=False) - local_config_prolog_mlir: Path = field(init=False) - local_config_epilog_mlir: Path = field(init=False) template_mlir: Path = field(init=False) candidates_dir: Path = field(init=False) - candidate_configs_pkl: Path = field(init=False) compiled_dir: Path = field(init=False) - compile_failed_dir: Path = field(init=False) specs_dir: Path = field(init=False) - output_unilog: Path = field(init=False) - result_summary_log: Path = field(init=False) - candidate_trackers_pkl: Path = field(init=False) - # To be set outside of class run_log: Optional[Path] = field(init=False, default=None) def __post_init__(self): object.__setattr__(self, "base_dir", self._name_base_dir()) - object.__setattr__( - self, "local_config_prolog_mlir", self.base_dir / "config_prolog.mlir" - ) - object.__setattr__( - self, "local_config_epilog_mlir", self.base_dir / "config_epilog.mlir" - ) object.__setattr__(self, "template_mlir", self.base_dir / "template.mlir") object.__setattr__(self, "candidates_dir", self.base_dir / "candidates") - object.__setattr__( - self, "candidate_configs_pkl", self.candidates_dir / "configs.pkl" - ) object.__setattr__(self, "compiled_dir", self.candidates_dir / "compiled") - object.__setattr__(self, "compile_failed_dir", self.candidates_dir / "failed") object.__setattr__(self, "specs_dir", self.candidates_dir / "specs") - object.__setattr__(self, "output_unilog", self.base_dir / "output.log") - object.__setattr__( - self, "result_summary_log", self.base_dir / "result_summary.log" - ) - object.__setattr__( - self, "candidate_trackers_pkl", self.base_dir / "candidate_trackers.pkl" - ) def _name_base_dir(self) -> Path: timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M") @@ -147,27 +95,12 @@ def _name_base_dir(self) -> Path: def _set_run_log(self, run_log: Path): object.__setattr__(self, "run_log", run_log) - def get_candidate_mlir_path(self, candidate_id: int) -> Path: - return self.candidates_dir / f"{candidate_id}.mlir" - - def get_candidate_spec_mlir_path(self, candidate_id: int) -> Path: - return self.candidates_dir / "specs" / f"{candidate_id}_spec.mlir" - - def get_exe_format(self, path: Path) -> str: - return f"./{path.as_posix()}" - - def get_compiled_dispatch_index(self, file_path: Path) -> int: - return int(file_path.stem) - def get_candidate_spec_filename(self, candidate_id: int) -> str: return f"{candidate_id}_spec.mlir" def get_candidate_vmfb_filename(self, 
candidate_id: int) -> str: return f"{candidate_id}.vmfb" - def get_compiled_model_index(self, file_path: Path) -> int: - return int(file_path.stem.split("_")[-1]) - class TuningClient(ABC): def __init__(self): @@ -183,50 +116,10 @@ def get_iree_compile_flags(self) -> list[str]: def get_iree_benchmark_module_flags(self) -> list[str]: pass - @abstractmethod - def get_dispatch_compile_command( - self, candidate_tracker: CandidateTracker - ) -> list[str]: - pass - - @abstractmethod - def get_dispatch_benchmark_command( - self, candidate_tracker: CandidateTracker - ) -> list[str]: - pass - - @abstractmethod - def get_model_compile_command( - self, candidate_tracker: CandidateTracker - ) -> list[str]: - pass - - @abstractmethod - def get_model_benchmark_command( - self, candidate_tracker: CandidateTracker - ) -> list[str]: - pass - @abstractmethod def get_benchmark_timeout_s(self) -> int: pass - @abstractmethod - def get_dispatch_compile_timeout_s(self) -> int: - pass - - @abstractmethod - def get_dispatch_benchmark_timeout_s(self) -> int: - pass - - @abstractmethod - def get_model_compile_timeout_s(self) -> int: - pass - - @abstractmethod - def get_model_benchmark_timeout_s(self) -> int: - pass - @dataclass class CompilePack: @@ -241,42 +134,6 @@ class BenchmarkPack: candidate_tracker: CandidateTracker -@dataclass -class RunPack: - command: list[str] - check: bool = True - timeout_seconds: Optional[int] = None - - -@dataclass -class RunResult: - process_res: Optional[subprocess.CompletedProcess] - is_timeout: bool - - -@dataclass -class TaskPack: - run_pack: RunPack - candidate_id: int - command_need_device_id: bool = False - cooling_time: int = 0 - - -@dataclass -class TaskResult: - run_result: RunResult - candidate_id: int - device_id: str - - -@dataclass -class ParsedDisptachBenchmarkResult: - candidate_id: int - benchmark_time_in_seconds: float - candidate_mlir: Path - candidate_spec_mlir: Path - - @dataclass class BenchmarkResult: candidate_id: int @@ -284,75 +141,17 @@ class BenchmarkResult: device_id: str -@dataclass -class IREEBenchmarkResult: - # Default format follows output of iree-benchmark-module - candidate_id: int - - # A list of dictionaries, each representing a benchmark result - # Each dictionary contains fields like: aggregate_name: string, real_time: float, cpu_time: float, time_unit: str, repetitions: int, etc. 
- result_json: list[dict[str, Any]] - - def get_mean_time_us(self) -> Optional[float]: - """Compute the mean time (in microseconds) for all of the benchmarks""" - if not self.result_json: - return None +def unit_to_microseconds(real_time: float, time_unit: str) -> float: + unit_conversions = { + "s": 1e6, + "ms": 1e3, + "us": 1, + "ns": 1e-3, + } - mean_benchmark = self.find_mean_benchmark(self.result_json) + assert time_unit in unit_conversions, f"Unsupported time unit: {time_unit}" - if mean_benchmark: - real_time: float | None = mean_benchmark.get("real_time") - time_unit: str | None = mean_benchmark.get("time_unit") - - if real_time is not None: - assert time_unit is not None - return self.unit_to_microseconds(real_time, time_unit) - - return None - - @staticmethod - def find_mean_benchmark(result_json: list[dict[str, Any]]) -> Optional[dict]: - for benchmark in result_json: - if benchmark.get("aggregate_name") == "mean": - return benchmark - - return None - - @staticmethod - def unit_to_microseconds(real_time: float, time_unit: str) -> float: - unit_conversions = { - "s": 1e6, - "ms": 1e3, - "us": 1, - "ns": 1e-3, - } - - assert time_unit in unit_conversions, f"Unsupported time unit: {time_unit}" - - return real_time * unit_conversions[time_unit] - - -def generate_display_DBR(candidate_id: int, mean_time: float) -> str: - """Generate dispatch_benchmark_result string for displaying""" - return f"{candidate_id}\tMean Time: {mean_time:.1f}" - - -def generate_display_MBR( - candidate_vmfb_path_str: str, - device_id: str, - t1: float, - calibrated_diff: Optional[float] = None, -) -> str: - """Generate model_benchmark_result string for displaying""" - if calibrated_diff: - percentage_change = calibrated_diff * 100 - change_str = f"({percentage_change:+.3f}%)" - res_str = f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}: {t1:.3g} {change_str}" - else: - res_str = ( - f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}: {t1:.3g}" - ) - return res_str + return real_time * unit_conversions[time_unit] def extract_driver_names(user_devices: list[str]) -> set[str]: @@ -605,85 +404,6 @@ def create_worker_context_queue(device_ids: list[int]) -> queue.Queue[tuple[int, return worker_contexts_queue -def run_command(run_pack: RunPack) -> RunResult: - command = run_pack.command - check = run_pack.check - timeout_seconds = run_pack.timeout_seconds - - result = None - is_timeout = False - try: - # Convert the command list to a command string for logging - command_str = " ".join(command) - logging.debug(f"Run: {command_str}") - - # Add timeout to subprocess.run call - result = subprocess.run( - command, - check=check, - capture_output=True, - text=True, - timeout=timeout_seconds, - ) - - if result.stdout: - logging.debug(f"stdout: {result.stdout}") - if result.stderr: - logging.debug(f"stderr: {result.stderr}") - except subprocess.TimeoutExpired as e: - logging.warning( - f"Command '{command_str}' timed out after {timeout_seconds} seconds." - ) - is_timeout = True - except subprocess.CalledProcessError as e: - print(e.output) - logging.error( - f"Command '{command_str}' returned non-zero exit status {e.returncode}." 
- ) - logging.error(f"Command '{command_str}' failed with error: {e.stderr}") - if check: - raise - except KeyboardInterrupt: - print("Ctrl+C detected, terminating child processes...") - - return RunResult(result, is_timeout) - - -# The `strip_root_op_attr` and `strip_compilation_info` functions are used for -# getting consistent inputs to the compilation step in tuning. Inputs may come -# in with lowering configs, translation info, and root_op attrs when the input -# is a benchmark, but not when the input is a source MLIR file. Stripping the -# info makes the inputs to compilation consistent, and allows for overwriting -# the compilation info with generated TD specs during codegen. -def strip_root_op_attr(module: ir.Module): - root_ops: list[ir.Operation] = get_ops_from_module(module, is_root_op) - for root_op in root_ops: - assert ( - ROOT_OP_ATTR_NAME in root_op.opview.attributes - ), f"expected root op to have '{ROOT_OP_ATTR_NAME}' attr" - del root_op.opview.attributes[ROOT_OP_ATTR_NAME] - - -# See the above comment for `strip_root_op_attr`. -def strip_compilation_info(input_path: Path) -> str: - # Strip compilation info from the source and save the stripped IR - strip_command = [ - f"iree-opt", - f"{input_path}", - f"--iree-codegen-strip-compilation-info", - ] - result = run_command( - RunPack( - command=strip_command, - check=True, - ) - ) - assert ( - result.process_res is not None - ), "expected result from stripping compilation info" - return result.process_res.stdout - - def run_iree_compile_command(compile_pack: CompilePack) -> Optional[int]: candidate_tracker = compile_pack.candidate_tracker @@ -790,11 +510,21 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack): assert ( len(time_and_unit) == 2 ), "expected the benchmark time to be the time and unit separated by a space." - time_us = IREEBenchmarkResult.unit_to_microseconds( + time_us = unit_to_microseconds( real_time=float(time_and_unit[0]), time_unit=time_and_unit[1], ) times.append(time_us) + + # If there are no times, then benchmarking failed at runtime. Record the + # time as math.inf. 
+ if len(times) == 0: + return BenchmarkResult( + candidate_id=candidate_id, + time=math.inf, + device_id=str(device_id), + ) + mean_benchmark_time = sum(times) / float(len(times)) logging.debug(f"Benchmark time of candidate {candidate_id}: {mean_benchmark_time}") return BenchmarkResult( @@ -804,30 +534,6 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack): ) -def run_command_wrapper(task_pack: TaskPack) -> TaskResult: - """Help handle extra requirements and record more data for run_command()""" - if task_pack.command_need_device_id: - # Worker searches for the special symbol and substitutes it with the actual device_id - pattern = re.compile(re.escape(DEVICE_ID_PLACEHOLDER)) - task_pack.run_pack.command = [ - pattern.sub(str(device_id), s) for s in task_pack.run_pack.command - ] - - run_result = run_command(task_pack.run_pack) - - task_result = TaskResult( - run_result, task_pack.candidate_id, device_id=str(-1) - ) # Main process - if device_id: - task_result = TaskResult( - run_result, task_pack.candidate_id, device_id - ) # Subprocess - - time.sleep(task_pack.cooling_time) - - return task_result - - def multiprocess_progress_wrapper( num_worker: int, task_list: list, @@ -861,44 +567,6 @@ def multiprocess_progress_wrapper( return results -def extract_benchmark_from_run_result( - run_result: RunResult, -) -> Optional[list[dict[str, Any]]]: - """Extract the benchmark from the result JSON""" - if run_result.process_res and run_result.process_res.stdout: - try: - result_json = json.loads(run_result.process_res.stdout) - - return result_json.get("benchmarks", None) - except json.JSONDecodeError as e: - handle_error( - condition=True, - msg=f"Failed to parse JSON from stdout: {e}", - error_type=ValueError, - exit_program=True, - ) - - return None - - -def numerical_sort_key(path: Path) -> tuple[int | float, str]: - """ - Define a sort key function that splits the filename into a numeric and a string part. - Order: 0 | 0_a | 0_b | 1 | 1_a | 2 - """ - numeric_part: int | float - # Extract the numeric part at the start of the filename - match = re.match(r"(\d+)", path.stem) - if match: - numeric_part = int(match.group(1)) - # The rest of the filename after the numeric part - remaining_part = path.stem[len(match.group(0)) :] - else: - numeric_part = float("inf") - remaining_part = path.stem - return (numeric_part, remaining_part) - - def calculate_md5(file_path: Path) -> str: md5 = hashlib.md5() with open(file_path, "rb") as f: @@ -933,111 +601,6 @@ def find_collisions( return collisions_exist, hash_values -def load_pickle(file_path: Path) -> list[Any]: - handle_error( - condition=(not file_path.exists()), - msg=f"Configuration file not found: {file_path}", - error_type=FileNotFoundError, - ) - with open(file_path, "rb") as file: - loaded_array = pickle.load(file) - return loaded_array - - -def save_pickle(file_path: Path, input_list: list[Any]) -> None: - with open(file_path, "wb") as file: - pickle.dump(input_list, file) - - -def append_to_file(lines: list[str], filepath: Path, title: str = "") -> None: - """Appends new content to the end of the output.log.""" - title_str = "=" * 5 + f" {title} " + "=" * 5 + "\n" if title != "" else "" - with open(filepath, "a") as file: - file.write(title_str) - file.writelines(lines) - file.write("\n") - - -# TODO(Max191): Remove in favor of using generate_candidate_specs. 
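`unit_to_microseconds`, promoted to a module-level helper earlier in this diff, normalizes the time units emitted by iree-benchmark-module. A quick sanity check of the conversion table (the definition is restated here only to keep the snippet self-contained):

```python
def unit_to_microseconds(real_time: float, time_unit: str) -> float:
    # Same table as the helper added in this diff.
    unit_conversions = {"s": 1e6, "ms": 1e3, "us": 1, "ns": 1e-3}
    assert time_unit in unit_conversions, f"Unsupported time unit: {time_unit}"
    return real_time * unit_conversions[time_unit]

assert unit_to_microseconds(1.5, "ms") == 1500.0
assert unit_to_microseconds(2.0, "s") == 2_000_000.0
assert unit_to_microseconds(250.0, "ns") == 0.25
```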
-def generate_candidates( - args: argparse.Namespace, - path_config: PathConfig, - candidate_trackers: list[CandidateTracker], -) -> list[int]: - """Generate candidate files for tuning. Returns the list of candidate indexes""" - logging.debug("generate_candidates()") - - try: - shutil.copy( - path_config.global_config_epilog_mlir, path_config.local_config_epilog_mlir - ) - shutil.copy( - path_config.global_config_prolog_mlir, path_config.local_config_prolog_mlir - ) - except FileNotFoundError as e: - handle_error( - condition=True, - msg=f"Configuration file not found: {e}", - error_type=FileNotFoundError, - ) - - shutil.copy(args.input_file, path_config.template_mlir) - - mlirs = [] - try: - logging.debug("Captured messages from candidate_gen.py:") - candidate_gen.tune( - input=str(path_config.template_mlir), - output=str(path_config.candidates_dir), - limit=args.num_candidates, - num_subgroups=args.num_subgroups, - lhs_dims=args.lhs_dims, - rhs_dims=args.rhs_dims, - tile_dims=args.tile_dims, - ) - mlirs = sorted( - path_config.candidates_dir.glob("*.mlir"), key=numerical_sort_key - ) - except Exception as e: - logging.error("An error occurred during candidates generation: %s", str(e)) - # Capture and log debug messages from candidate_gen.py - tune_logger = logging.getLogger("tune") - for handler in logging.getLogger().handlers: - if isinstance(handler, logging.FileHandler): - tune_logger.handlers.append(handler) - tune_logger.exception("Error in candidate_gen.py:") - raise - logging.debug("candidate_gen.py ends") - - candidate_configs = load_pickle(path_config.candidate_configs_pkl) - candidate_configs.insert(0, None) # No Configuration class for 0.mlir - - # Create candidate trackers - assert len(mlirs) // 2 + 1 == len(candidate_configs) - candidates = [] - for mlir in mlirs: - if "_config.mlir" not in mlir.name: - candidates.append(int(mlir.stem)) - new_candidate = CandidateTracker( - candidate_id=int(mlir.stem), - dispatch_mlir_path=mlir, - configuration=candidate_configs[int(mlir.stem)], - ) - candidate_trackers.append(new_candidate) - else: - candidate_trackers[ - int(mlir.stem.split("_config")[0]) - ].dispatch_config_path = mlir - - handle_error( - condition=(len(candidates) == 0), msg="Failed to generate any candidates" - ) - - logging.info(f"Generated [{len(candidates)}] candidates") - - return candidates - - def generate_candidate_specs( args: argparse.Namespace, path_config: PathConfig, @@ -1056,7 +619,7 @@ def generate_candidate_specs( # Strip compilation info before generating td_specs, since the generated # td_specs can end up matching against the compilation info from the # source mlir. - mlir_text = strip_compilation_info(path_config.template_mlir) + mlir_text = candidate_gen.strip_compilation_info(path_config.template_mlir) mlir_module = dispatch_parser.parse_mlir(mlir_text, tuning_client.tuner_context) with tuning_client.tuner_context.mlir_ctx: logging.debug("Captured messages from candidate_gen.py:") @@ -1140,10 +703,10 @@ def compile( # Strip compilation info and root_op attribute from the source and save # the stripped IR, since the TD specs do not expect these attributes. 
- stripped_mlir = strip_compilation_info(path_config.template_mlir) + stripped_mlir = candidate_gen.strip_compilation_info(path_config.template_mlir) context = tuning_client.tuner_context.mlir_ctx stripped_module = ir.Module.parse(stripped_mlir, context=context) - strip_root_op_attr(stripped_module) + candidate_gen.strip_root_op_attr(stripped_module) stripped_mlir = str(stripped_module) with open(path_config.template_mlir, "w") as f: f.write(stripped_mlir) @@ -1200,273 +763,6 @@ def compile( return compiled_candidates -# TODO(Max191): Remove in favor of using `compile` for both model and dispatch -# tuning. -def compile_dispatches( - args: argparse.Namespace, - path_config: PathConfig, - candidates: list[int], - candidate_trackers: list[CandidateTracker], - tuning_client: TuningClient, -) -> list[int]: - logging.debug("compile_dispatches()") - - if not candidates: - logging.warning("No candidates to compile.") - return [] - - path_config.compiled_dir.mkdir(parents=True, exist_ok=True) - path_config.compile_failed_dir.mkdir(parents=True, exist_ok=True) - path_config.specs_dir.mkdir(parents=True, exist_ok=True) - - task_list = [ - TaskPack( - RunPack( - command=tuning_client.get_dispatch_compile_command( - candidate_trackers[i] - ), - check=False, - timeout_seconds=tuning_client.get_dispatch_compile_timeout_s(), - ), - candidate_id=i, - ) - for i in candidates - ] - num_worker = min(args.max_cpu_workers, len(task_list)) - multiprocess_progress_wrapper( - num_worker=num_worker, task_list=task_list, function=run_command_wrapper - ) - - # Note: failed/incomplete candidates can also be detected by checking if subprocess.res is None - compiled_files = sorted( - path_config.compiled_dir.glob("*.vmfb"), key=numerical_sort_key - ) - failed_files = sorted( - path_config.compile_failed_dir.glob("*.mlir"), key=numerical_sort_key - ) - - total, good, bad = len(task_list), len(compiled_files), len(failed_files) - compiling_rate = good / total * 100 - logging.info( - f"Total: {total} | Compiled: {good} | Failed: {bad} | Compiling Rate: {compiling_rate:.1f}%" - ) - - # Update candidate tracker - for failed_file in failed_files: - index = path_config.get_compiled_dispatch_index(failed_file) - candidate_trackers[index].compilation_successful = False - compiled_candidates = [] - compiled_candidates_hash_list = [] - for compiled_file in compiled_files: - index = path_config.get_compiled_dispatch_index(compiled_file) - compiled_candidates.append(index) - candidate_trackers[index].compilation_successful = True - candidate_trackers[index].compiled_dispatch_path = compiled_file - compiled_vmfb_path = candidate_trackers[index].compiled_dispatch_path - assert compiled_vmfb_path is not None - hash_val = calculate_md5(compiled_vmfb_path) - candidate_trackers[index].compiled_dispatch_hash = hash_val - compiled_candidates_hash_list.append((index, hash_val)) - - handle_error( - condition=(good == 0), - msg="All candidate dispatches .mlir files failed to compile", - ) - handle_error( - condition=(compiling_rate < 10), - msg=f"Compiling rate [{compiling_rate:.1f}%] < 10%", - level=logging.WARNING, - ) - - collision_detected, unique_indexes = collision_handler( - compiled_candidates_hash_list - ) - if collision_detected: - logging.info(f"Remains [{len(unique_indexes)}] unique candidate indexes") - - return compiled_candidates if not collision_detected else unique_indexes - - -def parse_dispatch_benchmark_results( - path_config: PathConfig, - benchmark_results: list[TaskResult], - candidate_trackers: list[CandidateTracker], 
-) -> tuple[list[ParsedDisptachBenchmarkResult], list[str]]: - benchmark_result_configs = [] - dump_list = [] - incomplete_list = [] - - for benchmark_result in benchmark_results: - candidate_id = benchmark_result.candidate_id - process_res = benchmark_result.run_result.process_res - - if not process_res: - if benchmark_result.run_result.is_timeout: - incomplete_list.append(candidate_id) - continue - - res_json = extract_benchmark_from_run_result(benchmark_result.run_result) - assert res_json is not None - res = IREEBenchmarkResult(candidate_id, res_json) - benchmark_time = res.get_mean_time_us() - assert benchmark_time is not None - candidate_trackers[candidate_id].first_benchmark_time = benchmark_time - candidate_trackers[ - candidate_id - ].spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename( - candidate_id - ) - mlir_path = candidate_trackers[candidate_id].dispatch_mlir_path - spec_path = candidate_trackers[candidate_id].spec_path - assert mlir_path is not None and spec_path is not None - dump_list.append(generate_display_DBR(candidate_id, benchmark_time) + "\n") - - benchmark_result_configs.append( - ( - ParsedDisptachBenchmarkResult( - candidate_id, - benchmark_time, - mlir_path, - spec_path, - ) - ) - ) - - if incomplete_list: - dump_list += [f"Candidate {i} not completed" for i in incomplete_list] - - return benchmark_result_configs, dump_list - - -def generate_sample_task_result( - stdout: str, candidate_id: int, device_id: str -) -> TaskResult: - res = subprocess.CompletedProcess( - args=[""], - stdout=stdout, - returncode=0, - ) - run_result = RunResult(res, False) - return TaskResult( - run_result=run_result, candidate_id=candidate_id, device_id=device_id - ) - - -def generate_dryrun_dispatch_benchmark_results( - compiled_candidates: list[int], -) -> list[TaskResult]: - logging.debug("generate_dryrun_dispatch_benchmark_results()") - - task_results = [ - generate_sample_task_result( - f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms", - i, - str(0), - ) - for i in compiled_candidates - ] - - return task_results - - -def generate_dryrun_model_benchmark_results( - model_candidates: list[int], -) -> tuple[list[TaskResult], list[TaskResult]]: - candidate_results = [] - for i, j in enumerate(model_candidates): - stdout = f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms" - candidate_results.append(generate_sample_task_result(stdout, j, str(i % 3))) - - baseline_results = [ - generate_sample_task_result( - f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms", - 0, - str(i), - ) - for i in range(3) - ] - - return candidate_results, baseline_results - - -# TODO(Max191): Remove this function in favor of `benchmark`. 
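The `math.inf` fallback added to `run_iree_benchmark_module_command` above means candidates that fail at runtime sort naturally to the back when picking top results; a self-contained sketch (the mirror dataclass and device ids are illustrative):

```python
import math
from dataclasses import dataclass

@dataclass
class BenchmarkResult:  # mirror of the dataclass earlier in this diff
    candidate_id: int
    time: float
    device_id: str

results = [
    BenchmarkResult(1, 412.0, "0"),
    BenchmarkResult(2, math.inf, "0"),  # runtime failure recorded as inf
    BenchmarkResult(3, 388.5, "1"),
]
assert [r.candidate_id for r in sorted(results, key=lambda r: r.time)] == [3, 1, 2]
```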
-def benchmark_dispatches( - args: argparse.Namespace, - path_config: PathConfig, - compiled_candidates: list[int], - candidate_trackers: list[CandidateTracker], - tuning_client: TuningClient, -): - logging.debug("benchmark_dispatches()") - - if args.dry_run: - benchmark_results = generate_dryrun_dispatch_benchmark_results( - compiled_candidates - ) - else: - # Benchmarking dispatch candidates - task_list = [ - TaskPack( - RunPack( - command=tuning_client.get_dispatch_benchmark_command( - candidate_trackers[i] - ), - check=False, - timeout_seconds=tuning_client.get_dispatch_benchmark_timeout_s(), - ), - candidate_id=i, - command_need_device_id=True, - ) - for i in compiled_candidates - ] - worker_context_queue = create_worker_context_queue(args.devices) - benchmark_results = multiprocess_progress_wrapper( - num_worker=len(args.devices), - task_list=task_list, - function=run_command_wrapper, - initializer=init_worker_context, - initializer_inputs=(worker_context_queue,), - ) - - ( - parsed_benchmark_results, - dispatch_benchmark_dump_list, - ) = parse_dispatch_benchmark_results( - path_config, benchmark_results, candidate_trackers - ) - append_to_file( - dispatch_benchmark_dump_list, - filepath=path_config.output_unilog, - title="All Dispatch Benchmark Results", - ) - - benchmarking_rate = (len(parsed_benchmark_results) / len(benchmark_results)) * 100 - logging.info( - f"Total: {len(benchmark_results)} | Benchmarked: {len(parsed_benchmark_results)} | Failed: {len(benchmark_results) - len(parsed_benchmark_results)} | Benchmarking Rate: {benchmarking_rate:.1f}%" - ) - handle_error( - condition=(len(benchmark_results) == 0), - msg="Failed to benchmark all candidate .vmfb files", - ) - - # Select top candidates - best_results = sorted( - parsed_benchmark_results, key=lambda x: float(x.benchmark_time_in_seconds) - )[: args.num_model_candidates] - logging.info(f"Selected top[{len(best_results)}]") - - dump_list = [ - f"{result.benchmark_time_in_seconds}\t{result.candidate_mlir.as_posix()}\t{result.candidate_spec_mlir.as_posix()}\n" - for result in best_results - ] - append_to_file( - dump_list, filepath=path_config.output_unilog, title="Top Candidates Results" - ) - - top_candidates = [result.candidate_id for result in best_results] - return top_candidates - - def benchmark( args: argparse.Namespace, path_config: PathConfig, @@ -1533,317 +829,3 @@ def get_speedup(result: BenchmarkResult) -> float: top_candidates = [result.candidate_id for result in best_results] return top_candidates - - -# TODO(Max191): Remove in favor of using `compile` for both model and dispatch -# tuning. 
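For reference, vmfb collision filtering pairs `calculate_md5` (kept above) with `collision_handler`, as in the deleted `compile_models` below; the call shape is inferred from its call sites in this diff, and the paths are hypothetical:

```python
from pathlib import Path

# Hypothetical artifact paths; assumes the compiled .vmfb files exist on disk.
hash_list = [
    (i, calculate_md5(Path(f"candidates/compiled/{i}.vmfb"))) for i in (1, 2, 3)
]
collision_detected, unique_indexes = collision_handler(hash_list)
if collision_detected:
    print(f"Remains [{len(unique_indexes)}] unique candidate indexes")
```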
-def compile_models(
-    args: argparse.Namespace,
-    path_config: PathConfig,
-    candidates: list[int],
-    candidate_trackers: list[CandidateTracker],
-    tuning_client: TuningClient,
-) -> list[int]:
-    logging.debug("compile_models()")
-
-    candidate_trackers[0].compiled_model_path = path_config.model_baseline_vmfb
-
-    if args.dry_run:
-        for i in candidates:
-            candidate_trackers[i].compiled_model_path = Path(f"model_{i}.vmfb")
-        return candidates
-
-    if not candidates:
-        logging.warning("No model candidates to compile.")
-        return []
-
-    task_list = [
-        TaskPack(
-            RunPack(
-                command=tuning_client.get_model_compile_command(candidate_trackers[i]),
-                check=False,
-                timeout_seconds=tuning_client.get_model_compile_timeout_s(),
-            ),
-            candidate_id=i,
-        )
-        for i in candidates
-        if i != 0
-    ]
-    num_worker = min(args.max_cpu_workers, len(task_list))
-    multiprocess_progress_wrapper(
-        num_worker=num_worker, task_list=task_list, function=run_command_wrapper
-    )
-
-    model_candidates_files = list(path_config.base_dir.glob("*.vmfb"))
-
-    model_candidates_indexes = []
-    model_candidates_hash_list = []
-
-    # Update candidate tracker
-    for model_candidate in model_candidates_files:
-        assert model_candidate is not None
-        index = path_config.get_compiled_model_index(model_candidate)
-        candidate_trackers[index].compiled_model_path = model_candidate
-        hash_val = calculate_md5(model_candidate)
-        candidate_trackers[index].compiled_model_hash = hash_val
-        model_candidates_hash_list.append((index, hash_val))
-        model_candidates_indexes.append(index)
-
-    # Check if model candidates produce the same .vmfb
-    collision_detected, unique_model_candidates_indexes = collision_handler(
-        model_candidates_hash_list
-    )
-
-    if collision_detected:
-        logging.info(
-            f"Remains [{len(unique_model_candidates_indexes)}] unique candidate indexes"
-        )
-
-    return (
-        unique_model_candidates_indexes
-        if collision_detected
-        else model_candidates_indexes
-    )
-
-
-def group_benchmark_results_by_device_id(
-    benchmark_results: list[TaskResult],
-) -> list[list[TaskResult]]:
-    """
-    Groups benchmark results by device ID.
-
-    e.g.
- [TaskResult(res1, device_1), TaskResult(res2, device_2), TaskResult(res3, device_1)] - -----> - [ [TaskResult(res1, device_1), TaskResult(res3, device_1)], [TaskResult(res2, device_2)] ] - """ - grouped_results: dict[str, list[TaskResult]] = {} - for result in benchmark_results: - assert result.device_id is not None - if result.device_id not in grouped_results: - grouped_results[result.device_id] = [] - grouped_results[result.device_id].append(result) - - grouped_benchmark_results = [ - grouped_results[device_id] for device_id in sorted(grouped_results) - ] - - return grouped_benchmark_results - - -def parse_model_benchmark_results( - candidate_trackers: list[CandidateTracker], - candidate_results: list[TaskResult], - baseline_results: list[TaskResult], -): - """Update candidate_tracker and format a list of result strings to be saved later.""" - candidate_results = sorted(candidate_results, key=lambda br: br.device_id) - baseline_results = sorted(baseline_results, key=lambda tr: tr.device_id) - - # Assign candidates to the same groups by device_id - grouped_candidate_results = group_benchmark_results_by_device_id(candidate_results) - - # Insert baseline results to the head of each list - grouped_benchmark_results = [ - [x] + y for x, y in zip(baseline_results, grouped_candidate_results) - ] - - dump_list = [] - incomplete_list: list[ - tuple[int, Optional[str]] - ] = [] # format: [(candidate_id, device_id)] - - baseline_time = None - for same_device_results in grouped_benchmark_results: - dump_unsort_list: list[tuple[float, str]] = [] - for task_result in same_device_results: - candidate_id = task_result.candidate_id - device_id = task_result.device_id - process_res = task_result.run_result.process_res - - # Check if benchmarking has completed - if not process_res: - if task_result.run_result.is_timeout: - incomplete_list.append((candidate_id, device_id)) - if candidate_id == 0: - baseline_time = None - continue - - result_json = extract_benchmark_from_run_result(task_result.run_result) - assert result_json is not None - res = IREEBenchmarkResult(candidate_id, result_json) - benchmark_time = res.get_mean_time_us() - assert benchmark_time is not None - - # Record baseline benchmarking result and skip rest processes - if candidate_id == 0: - baseline_time = benchmark_time - baseline_vmfb_path = candidate_trackers[ - candidate_id - ].compiled_model_path - assert baseline_vmfb_path is not None - dump_str = ( - generate_display_MBR( - candidate_vmfb_path_str=baseline_vmfb_path.as_posix(), - device_id=device_id, - t1=benchmark_time, - ) - + "\n\n" - ) - dump_list.append(dump_str) - continue - - # Update candidate_tracker - candidate_trackers[candidate_id].model_benchmark_time = benchmark_time - candidate_trackers[candidate_id].model_benchmark_device_id = device_id - - # Calculate candidate improvement based on baseline. 
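# (Editorial aside) The "calibrated" diff computed just below is the relative
# slowdown (or speedup, when negative) of a candidate against the baseline
# measured on the same device. Worked example, using the values that appear
# in the unit tests further down:
baseline_time = 0.98  # mean time of baseline.vmfb on device_1
benchmark_time = 1.23  # mean time of the candidate on device_1

calibrated_benchmark_diff = (benchmark_time - baseline_time) / baseline_time
print(f"({calibrated_benchmark_diff * 100:+.3f}%)")  # -> (+25.510%)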
- if baseline_time: - candidate_trackers[candidate_id].baseline_benchmark_time = baseline_time - calibrated_benchmark_diff = ( - benchmark_time - baseline_time - ) / baseline_time - candidate_trackers[ - candidate_id - ].calibrated_benchmark_diff = calibrated_benchmark_diff - else: - calibrated_benchmark_diff = None - - # Collect candidate dump str - candidate_vmfb_path = candidate_trackers[candidate_id].compiled_model_path - assert candidate_vmfb_path is not None - dump_str = ( - generate_display_MBR( - candidate_vmfb_path_str=candidate_vmfb_path.as_posix(), - device_id=device_id, - t1=benchmark_time, - calibrated_diff=calibrated_benchmark_diff, - ) - + "\n\n" - ) - - dump_unsort_list.append((benchmark_time, dump_str)) - - # Sort model candidate benchmarking result str in ascending time order. - dump_list = dump_list + [ - dump_str for _, dump_str in sorted(dump_unsort_list, key=lambda x: x[0]) - ] - - # Store incomplete .vmfb file at the end of dump_list. - for index, device in incomplete_list: - file_path = candidate_trackers[index].compiled_model_path - assert file_path is not None - error_msg = f"Benchmarking result of {file_path.as_posix()} on device {device} is incomplete" - handle_error(condition=True, msg=error_msg, level=logging.WARNING) - dump_list.append(error_msg + "\n") - - return dump_list - - -# TODO(Max191): Remove this function in favor of `benchmark`. -def benchmark_models( - args: argparse.Namespace, - path_config: PathConfig, - model_candidates: list[int], - candidate_trackers: list[CandidateTracker], - tuning_client: TuningClient, -): - """Benchmark U-Net candidate files and log the results.""" - logging.debug("benchmark_models()") - - if args.dry_run: - candidate_results, baseline_results = generate_dryrun_model_benchmark_results( - model_candidates - ) - else: - # Benchmarking model candidates - worker_context_queue = create_worker_context_queue(args.devices) - benchmark_task_list = [ - TaskPack( - RunPack( - command=tuning_client.get_model_benchmark_command( - candidate_trackers[i] - ), - check=False, - timeout_seconds=tuning_client.get_dispatch_benchmark_timeout_s(), - ), - candidate_id=i, - command_need_device_id=True, - cooling_time=10, - ) - for i in model_candidates - ] - candidate_results = multiprocess_progress_wrapper( - num_worker=len(args.devices), - task_list=benchmark_task_list, - function=run_command_wrapper, - initializer=init_worker_context, - initializer_inputs=(worker_context_queue,), - ) - - # Benchmarking baselines on each involved device - candidate_trackers[0].compiled_model_path = path_config.model_baseline_vmfb - worker_context_queue = create_worker_context_queue(args.devices) - baseline_task_list = [ - TaskPack( - RunPack( - command=tuning_client.get_model_benchmark_command( - candidate_trackers[0] - ), - check=False, - timeout_seconds=tuning_client.get_model_benchmark_timeout_s(), - ), - candidate_id=0, - command_need_device_id=True, - ) - ] * len(group_benchmark_results_by_device_id(candidate_results)) - baseline_results = multiprocess_progress_wrapper( - num_worker=len(args.devices), - task_list=baseline_task_list, - function=run_command_wrapper, - initializer=init_worker_context, - initializer_inputs=(worker_context_queue,), - ) - - dump_list = parse_model_benchmark_results( - candidate_trackers, candidate_results, baseline_results - ) - - append_to_file( - dump_list, filepath=path_config.output_unilog, title="Model Benchmark Results" - ) - - -def summarize_top_candidates( - path_config: PathConfig, candidate_trackers: 
list[CandidateTracker]
-):
-    dump_list = []
-    top_candidates = []
-    for candidate in candidate_trackers:
-        if candidate.candidate_id == 0 or candidate.model_benchmark_time is None:
-            continue
-        top_candidates.append(
-            (candidate.candidate_id, candidate.model_benchmark_time)
-        )  # collect (id, time)
-
-    top_candidates = sorted(
-        top_candidates, key=lambda x: x[1]
-    )  # sort the list in ascending benchmark time order
-    top_candidate_ids = [item[0] for item in top_candidates]  # get list of candidate ids
-
-    for candidate_id in top_candidate_ids:
-        candidate = candidate_trackers[candidate_id]
-        assert candidate.dispatch_config_path is not None
-        with open(candidate.dispatch_config_path, "r") as file:
-            config_file_contents = file.read()
-        final_str = f"Candidate {candidate.candidate_id}:\nModel benchmark time: {candidate.model_benchmark_time} on device {candidate.model_benchmark_device_id}\nDispatch benchmark time: {candidate.first_benchmark_time} on device {candidate.model_benchmark_device_id}\nSpec file path: {candidate.spec_path}\nSpec contents:{config_file_contents}\n\n"
-        dump_list.append(final_str)
-
-    with open(path_config.result_summary_log, "w") as file:
-        file.writelines(dump_list)
-
-
-def sanitize_filename(filename: str) -> str:
-    # Replace invalid characters with an underscore
-    sanitized = re.sub(r"[^\w\.-]", "_", filename)
-    return sanitized
diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py
index 1b659268b..767a6aff4 100644
--- a/tuner/tuner/libtuner_test.py
+++ b/tuner/tuner/libtuner_test.py
@@ -16,31 +16,6 @@
 """
 
 
-def test_group_benchmark_results_by_device_id() -> None:
-    # Create mock TaskResult objects with device_id attributes
-    task_result_1: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult)
-    task_result_1.device_id = "device_1"
-
-    task_result_2: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult)
-    task_result_2.device_id = "device_2"
-
-    task_result_3: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult)
-    task_result_3.device_id = "device_1"
-
-    benchmark_results = [task_result_1, task_result_2, task_result_3]
-
-    expected_grouped_results = [
-        [task_result_1, task_result_3],  # Grouped by device_1
-        [task_result_2],  # Grouped by device_2
-    ]
-
-    grouped_results = libtuner.group_benchmark_results_by_device_id(benchmark_results)
-
-    assert grouped_results == expected_grouped_results
-    assert grouped_results[0][0].device_id == "device_1"
-    assert grouped_results[1][0].device_id == "device_2"
-
-
 def test_find_collisions() -> None:
     input = [(1, "abc"), (2, "def"), (3, "abc")]
     assert libtuner.find_collisions(input) == (True, [("abc", [1, 3]), ("def", [2])])
@@ -58,307 +33,6 @@
 def test_collision_handler() -> None:
     input = [(4, "ghi"), (5, "jkl")]
     assert libtuner.collision_handler(input) == (False, [])
-
-
-def test_IREEBenchmarkResult_get() -> None:
-    # Time is int in us
-    int_json = [{"aggregate_name": "mean", "real_time": 1, "time_unit": "us"}]
-
-    res = libtuner.IREEBenchmarkResult(candidate_id=1, result_json=int_json)
-    assert res.get_mean_time_us() == float(1)
-
-    # Time is float in us
-    float_json = [{"aggregate_name": "mean", "real_time": 123.45, "time_unit": "us"}]
-
-    res = libtuner.IREEBenchmarkResult(candidate_id=2, result_json=float_json)
-    assert res.get_mean_time_us() == 123.45
-
-    # Time is in seconds
-    seconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "s"}]
-
-    res = libtuner.IREEBenchmarkResult(candidate_id=3, result_json=seconds_json)
-    assert res.get_mean_time_us() == 1.0 * 1e6
-
-    # Time is in milliseconds
-    miliseconds_json =
[{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "ms"}] - - res = libtuner.IREEBenchmarkResult(candidate_id=4, result_json=miliseconds_json) - assert res.get_mean_time_us() == 1.0 * 1e3 - - # Time is in nanoseconds - nanoseconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "ns"}] - - res = libtuner.IREEBenchmarkResult(candidate_id=5, result_json=nanoseconds_json) - assert res.get_mean_time_us() == 1.0 * 1e-3 - - small_number_json = [ - { - "aggregate_name": "mean", - "real_time": 3.4591828516259519e-02, - "time_unit": "ms", - } - ] - - res = libtuner.IREEBenchmarkResult(candidate_id=6, result_json=small_number_json) - assert res.get_mean_time_us() == 34.591828516259519 - - # Invalid json: missing real_time - invalid_real_time_json = [{"aggregate_name": "mean", "real_time": None}] - - res = libtuner.IREEBenchmarkResult( - candidate_id=7, result_json=invalid_real_time_json - ) - assert res.get_mean_time_us() == None - - # Invalid json: empty dictionary - res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json=[]) - assert res.get_mean_time_us() is None - - # Invalid json: invalid time unit - invalid_time_unit_json = [ - {"aggregate_name": "mean", "real_time": 1.0, "time_unit": "invalid_unit"} - ] - - with pytest.raises(AssertionError, match="Unsupported time unit: invalid_unit"): - res = libtuner.IREEBenchmarkResult( - candidate_id=9, result_json=invalid_time_unit_json - ) - res.get_mean_time_us() - - # Invalid json: missing aggregate_name - invalid_aggregate_name_json = [{"real_time": 1.0, "time_unit": "us"}] - - res = libtuner.IREEBenchmarkResult( - candidate_id=10, result_json=invalid_aggregate_name_json - ) - assert res.get_mean_time_us() is None - - -def test_generate_display_BR() -> None: - output = libtuner.generate_display_DBR(1, 3.14) - expected = f"1\tMean Time: 3.1" - assert output == expected, "DispatchBenchmarkResult generates invalid sample string" - - output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89) - expected = "Benchmarking: baseline.vmfb on device 1: 568" - assert output == expected, "ModelBenchmarkResult generates invalid sample string" - output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, 0.0314) - expected = "Benchmarking: baseline.vmfb on device 1: 568 (+3.140%)" - assert output == expected, "ModelBenchmarkResult generates invalid sample string" - output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, -3.14) - expected = "Benchmarking: baseline.vmfb on device 1: 568 (-314.000%)" - assert output == expected, "ModelBenchmarkResult generates invalid sample string" - - -def make_mock_task_result() -> libtuner.TaskResult: - process: CompletedProcess = MagicMock(spec=CompletedProcess) - run_result = libtuner.RunResult(process, False) - task_result = libtuner.TaskResult(run_result, 0, "") - return task_result - - -def test_parse_dispatch_benchmark_results() -> None: - base_path = libtuner.Path("/mock/base/dir") - spec_dir = base_path / "specs" - path_config = libtuner.PathConfig() - object.__setattr__(path_config, "specs_dir", spec_dir) - - mock_result_1 = make_mock_task_result() - mock_json_1 = { - "benchmarks": [ - {"aggregate_name": "mean", "real_time": 100.0, "time_unit": "us"} - ] - } - assert mock_result_1.run_result.process_res is not None - mock_result_1.run_result.process_res.stdout = json.dumps(mock_json_1) - mock_result_1.candidate_id = 1 - mock_result_2 = make_mock_task_result() - mock_json_2 = { - "benchmarks": [ - {"aggregate_name": "mean", "real_time": 200.0, 
"time_unit": "us"} - ] - } - assert mock_result_2.run_result.process_res is not None - mock_result_2.run_result.process_res.stdout = json.dumps(mock_json_2) - mock_result_2.candidate_id = 2 - mock_result_3 = make_mock_task_result() - mock_json_3 = { - "benchmarks": [ - { - "aggregate_name": "mean", - "real_time": 3.4591828516259519e-02, - "time_unit": "ms", - } - ] - } - assert mock_result_3.run_result.process_res is not None - mock_result_3.run_result.process_res.stdout = json.dumps(mock_json_3) - mock_result_3.candidate_id = 3 - # Incomplete result. - mock_result_4 = libtuner.TaskResult(libtuner.RunResult(None, True), 4, "4") - benchmark_results = [mock_result_1, mock_result_2, mock_result_3, mock_result_4] - - candidate_trackers = [] - for i in range(4): - tracker = libtuner.CandidateTracker(candidate_id=i) - tracker.dispatch_mlir_path = libtuner.Path(f"/mock/mlir/path/{i}.mlir") - candidate_trackers.append(tracker) - - expected_parsed_results = [ - libtuner.ParsedDisptachBenchmarkResult( - candidate_id=1, - benchmark_time_in_seconds=100.0, - candidate_mlir=libtuner.Path("/mock/mlir/path/1.mlir"), - candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/1_spec.mlir"), - ), - libtuner.ParsedDisptachBenchmarkResult( - candidate_id=2, - benchmark_time_in_seconds=200.0, - candidate_mlir=libtuner.Path("/mock/mlir/path/2.mlir"), - candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/2_spec.mlir"), - ), - libtuner.ParsedDisptachBenchmarkResult( - candidate_id=3, - benchmark_time_in_seconds=34.591828516259519, - candidate_mlir=libtuner.Path("/mock/mlir/path/3.mlir"), - candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/3_spec.mlir"), - ), - ] - expected_dump_list = [ - "1\tMean Time: 100.0\n", - "2\tMean Time: 200.0\n", - "3\tMean Time: 34.6\n", - "Candidate 4 not completed", - ] - - parsed_results, dump_list = libtuner.parse_dispatch_benchmark_results( - path_config, benchmark_results, candidate_trackers - ) - - assert parsed_results == expected_parsed_results - assert dump_list == expected_dump_list - assert candidate_trackers[1].first_benchmark_time == 100.0 - assert candidate_trackers[1].spec_path == libtuner.Path( - "/mock/base/dir/specs/1_spec.mlir" - ) - assert candidate_trackers[2].first_benchmark_time == 200.0 - assert candidate_trackers[2].spec_path == libtuner.Path( - "/mock/base/dir/specs/2_spec.mlir" - ) - assert candidate_trackers[3].first_benchmark_time == 34.591828516259519 - assert candidate_trackers[3].spec_path == libtuner.Path( - "/mock/base/dir/specs/3_spec.mlir" - ) - - -def test_parse_model_benchmark_results() -> None: - # Setup mock data for candidate_trackers - tracker0 = libtuner.CandidateTracker(0) - tracker0.compiled_model_path = libtuner.Path("/path/to/baseline.vmfb") - - tracker1 = libtuner.CandidateTracker(1) - tracker1.compiled_model_path = libtuner.Path("/path/to/model_1.vmfb") - - tracker2 = libtuner.CandidateTracker(2) - tracker2.compiled_model_path = libtuner.Path("/path/to/model_2.vmfb") - - tracker3 = libtuner.CandidateTracker(3) - tracker3.compiled_model_path = libtuner.Path("/path/to/model_3.vmfb") - - candidate_trackers = [tracker0, tracker1, tracker2, tracker3] - - # Setup mock data for task results - result1 = make_mock_task_result() - result_json_1 = {"benchmarks": [{"real_time": 1.23}]} - assert result1.run_result.process_res is not None - result1.run_result.process_res.stdout = json.dumps(result_json_1) - result1.candidate_id = 1 - result1.device_id = "device1" - - result2 = make_mock_task_result() - result_json_2 = {"benchmarks": 
[{"real_time": 4.56}]} - assert result2.run_result.process_res is not None - result2.run_result.process_res.stdout = json.dumps(result_json_2) - result2.candidate_id = 2 - result2.device_id = "device2" - - result3 = make_mock_task_result() - result_json_3 = {"benchmarks": [{"real_time": 0.98}]} - assert result3.run_result.process_res is not None - result3.run_result.process_res.stdout = json.dumps(result_json_3) - result3.candidate_id = 0 - result3.device_id = "device1" - - result4 = make_mock_task_result() - result_json_4 = {"benchmarks": [{"real_time": 4.13}]} - assert result4.run_result.process_res is not None - result4.run_result.process_res.stdout = json.dumps(result_json_4) - result4.candidate_id = 0 - result4.device_id = "device2" - - # Incomplete baseline on device3 - result5 = libtuner.TaskResult(libtuner.RunResult(None, True), 0, "device3") - - result6 = make_mock_task_result() - result_json_6 = {"benchmarks": [{"real_time": 3.38}]} - assert result6.run_result.process_res is not None - result6.run_result.process_res.stdout = json.dumps(result_json_6) - result6.candidate_id = 3 - result6.device_id = "device3" - - candidate_results = [result1, result2, result6] - baseline_results = [result3, result4, result5] - - # Skip real benchmark extraction, directly use given values from above - def mock_get_mean_time_us(self): - return float(self.result_json[0]["real_time"]) if self.result_json else None - - # Mock IREEBenchmarkResult to return wanted benchmark times - with patch( - f"{libtuner.__name__}.IREEBenchmarkResult.get_mean_time_us", - new=mock_get_mean_time_us, - ): - # Mock handle_error to avoid actual logging during tests - with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: - dump_list = libtuner.parse_model_benchmark_results( - candidate_trackers, candidate_results, baseline_results - ) - - # Verify interactions with candidate_trackers - assert tracker1.model_benchmark_time == 1.23 - assert tracker1.model_benchmark_device_id == "device1" - assert tracker1.baseline_benchmark_time == 0.98 - assert tracker1.calibrated_benchmark_diff == pytest.approx( - (1.23 - 0.98) / 0.98, rel=1e-6 - ) - - assert tracker2.model_benchmark_time == 4.56 - assert tracker2.model_benchmark_device_id == "device2" - assert tracker2.baseline_benchmark_time == 4.13 - assert tracker2.calibrated_benchmark_diff == pytest.approx( - (4.56 - 4.13) / 4.13, rel=1e-6 - ) - - assert tracker3.model_benchmark_time == 3.38 - assert tracker3.model_benchmark_device_id == "device3" - - assert dump_list == [ - "Benchmarking: /path/to/baseline.vmfb on device device1: 0.98\n" "\n", - "Benchmarking: /path/to/model_1.vmfb on device device1: 1.23 (+25.510%)\n" - "\n", - "Benchmarking: /path/to/baseline.vmfb on device device2: 4.13\n" "\n", - "Benchmarking: /path/to/model_2.vmfb on device device2: 4.56 (+10.412%)\n" - "\n", - "Benchmarking: /path/to/model_3.vmfb on device device3: 3.38\n" "\n", - "Benchmarking result of /path/to/baseline.vmfb on device device3 is incomplete\n", - ] - - # Verify handle_error was called correctly - mock_handle_error.assert_called_once_with( - condition=True, - msg="Benchmarking result of /path/to/baseline.vmfb on device device3 is incomplete", - level=libtuner.logging.WARNING, - ) - - def test_extract_driver_names() -> None: user_devices = ["hip://0", "local-sync://default", "cuda://default"] expected_output = {"hip", "local-sync", "cuda"}