Skip to content

Commit

Permalink
[Tuner] Clean up sample tuner
Browse files Browse the repository at this point in the history
Rename it from 'test' to 'simple' to avoid mistaking it for a test:
https://github.com/nod-ai/shark-ai/actions/runs/12659028888/job/35277223146#step:6:26 .

Also:
* Update and improve the README (account for the directory structure)
* Make flag naming consistent
* Handle previously missing `--stop-after` phases
  • Loading branch information
kuhar committed Jan 8, 2025
1 parent 04b1819 commit 8fcbf8f
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 28 deletions.
2 changes: 2 additions & 0 deletions tuner/examples/simple/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tmp

26 changes: 15 additions & 11 deletions tuner/examples/test/README.md → tuner/examples/simple/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Example Tuner Test
# Simple Example Tuner

Example of tuning a dispatch and full model.
Example of tuning a dispatch and a full model.

## Environments
Follow instructions in [`/tuner/README.md`](../README.md)
Expand All @@ -15,27 +15,31 @@ Use the usual `iree-compile` command for your model, add
`--iree-hal-dump-executable-files-to=dump --iree-config-add-tuner-attributes`,
and get the dispatch benchmark that you want to tune. For example:
```shell
mkdir tmp
iree-compile double_mmt.mlir --iree-hal-target-backends=rocm \
--iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=dump \
--iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=tmp/dump \
--iree-config-add-tuner-attributes -o /dev/null

cp dump/module_main_dispatch_0_rocm_hsaco_fb_benchmark.mlir mmt_benchmark.mlir
cp tmp/dump/module_main_dispatch_0_rocm_hsaco_fb_benchmark.mlir tmp/mmt_benchmark.mlir
```

### Recommended Trial Run
For an initial trial to test the tuning loop, use:
```shell
python -m examples.test double_mmt.mlir mmt_benchmark.mlir \
--test_num_dispatch_candidates=5 --test_num_model_candidates=3 \
--num-candidates=30
cd ../../
python -m examples.simple examples/simple/double_mmt.mlir \
examples/simple/tmp/mmt_benchmark.mlir \
--devices=hip://0 --num-candidates=30 \
--simple-num-dispatch-candidates=5 --simple-num-model-candidates=3 \
```

### Basic Usage
```shell
python -m examples.test <model_file_path> <benchmark_file_path> \
--test_num_dispatch_candidates=<num_dispatch_candidates> \
--test_num_model_candidates=<num_model_candidates> \
--test_hip_target=<hip_target> \
python -m examples.simple <model_file_path> <benchmark_file_path> \
--devices=hip://0 --num-candidates=1024 \
--test-num-dispatch-candidates=<num_dispatch_candidates> \
--test-num-model-candidates=<num_model_candidates> \
--test-hip-target=<hip_target> \
--num-candidates=<num_generated_candidates> \
--codegen-pipeline=<codegen_pipeline>
```
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from . import tuner_test
from . import simple_tuner

tuner_test.main()
simple_tuner.main()
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,22 @@ def main():
parser = argparse.ArgumentParser(description="Autotune test script")
test_args = parser.add_argument_group("Example Test Options")
test_args.add_argument(
"test_model_file", type=Path, help="Path to the model file to tune (.mlir)"
"simple_model_file", type=Path, help="Path to the model file to tune (.mlir)"
)
test_args.add_argument(
"--test_num_dispatch_candidates",
"--simple-num-dispatch-candidates",
type=int,
default=None,
help="Number of dispatch candidates to keep for model benchmarks.",
)
test_args.add_argument(
"--test_num_model_candidates",
"--simple-num-model-candidates",
type=int,
default=None,
help="Number of model candidates to produce after tuning.",
)
test_args.add_argument(
"--test_hip_target",
"--simple-hip-target",
type=str,
default="gfx942",
help="Hip target for tuning.",
Expand All @@ -98,51 +98,54 @@ def main():
libtuner.setup_logging(args, path_config)
print(path_config.run_log, end="\n\n")

# TODO(Max191): Some bug seems to be causing OOM errors in benchmarking
# when device validation happens, so this is commented for now. Uncomment
# when the bug is fixed.
if not args.dry_run:
print("Validating devices")
libtuner.validate_devices(args.devices)
print("Validation successful!\n")

print("Generating candidates...")
print("Generating candidate tuning specs...")
test_tuner = TestTuner()
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
)
print(f"Stored candidate specs in {path_config.specs_dir}\n")
print(f"Stored candidate tuning specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling candidates...")
print("Compiling dispatch candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
)
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
return

print("Benchmarking compiled candidates...")
print("Benchmarking compiled dispatch candidates...")
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
args.test_num_dispatch_candidates,
args.simple_num_dispatch_candidates,
)
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
return

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.test_hip_target}",
f"--iree-hip-target={args.simple_hip_target}",
]
compiled_model_candidates = libtuner.compile(
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.test_model_file,
args.simple_model_file,
)
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
return

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
Expand All @@ -156,7 +159,7 @@ def main():
compiled_model_candidates,
candidate_trackers,
test_tuner,
args.test_num_model_candidates,
args.simple_num_model_candidates,
)

print(f"Top model candidates: {top_model_candidates}")
Expand Down

0 comments on commit 8fcbf8f

Please sign in to comment.