Skip to content

Commit

Permalink
[tuner] Clean up simple tuner example script (#785)
Browse files Browse the repository at this point in the history
This PR addresses the TODO in `tuner/examples/simple/simple_tuner.py` to
remove the unused abstract function implementations in the tuner client.
The PR also renames some variables to be more consistent with the name
of the example.

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored and eagarvey-amd committed Jan 8, 2025
1 parent 63057e2 commit fc79ce9
Showing 1 changed file with 16 additions and 51 deletions.
67 changes: 16 additions & 51 deletions tuner/examples/simple/simple_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tuner.common import *


class TestTuner(libtuner.TuningClient):
class SimpleTuner(libtuner.TuningClient):
def __init__(self, tuner_context: libtuner.TunerContext):
super().__init__(tuner_context)
self.compile_flags = ["--compile-from=executable-sources"]
Expand All @@ -25,62 +25,27 @@ def get_iree_benchmark_module_flags(self) -> list[str]:
def get_benchmark_timeout_s(self) -> int:
return 10

# TODO(Max191): Remove the following unused abstract functions once they
# are removed from the TuningClient definition.
def get_dispatch_benchmark_timeout_s(self) -> int:
return 0

def get_dispatch_compile_timeout_s(self) -> int:
return 0

def get_dispatch_compile_command(
self, candidate_tracker: libtuner.CandidateTracker
) -> list[str]:
return []

def get_dispatch_benchmark_command(
self,
candidate_tracker: libtuner.CandidateTracker,
) -> list[str]:
return []

def get_model_compile_timeout_s(self) -> int:
return 0

def get_model_compile_command(
self, candidate_tracker: libtuner.CandidateTracker
) -> list[str]:
return []

def get_model_benchmark_timeout_s(self) -> int:
return 0

def get_model_benchmark_command(
self, candidate_tracker: libtuner.CandidateTracker
) -> list[str]:
return []


def main():
# Custom arguments for the test file.
parser = argparse.ArgumentParser(description="Autotune test script")
test_args = parser.add_argument_group("Example Test Options")
test_args.add_argument(
# Custom arguments for the example tuner file.
parser = argparse.ArgumentParser(description="Autotune sample script")
client_args = parser.add_argument_group("Simple Example Tuner Options")
client_args.add_argument(
"simple_model_file", type=Path, help="Path to the model file to tune (.mlir)"
)
test_args.add_argument(
client_args.add_argument(
"--simple-num-dispatch-candidates",
type=int,
default=None,
help="Number of dispatch candidates to keep for model benchmarks.",
)
test_args.add_argument(
client_args.add_argument(
"--simple-num-model-candidates",
type=int,
default=None,
help="Number of model candidates to produce after tuning.",
)
test_args.add_argument(
client_args.add_argument(
"--simple-hip-target",
type=str,
default="gfx942",
Expand All @@ -106,17 +71,17 @@ def main():

print("Generating candidate tuning specs...")
with TunerContext() as tuner_context:
test_tuner = TestTuner(tuner_context)
simple_tuner = SimpleTuner(tuner_context)
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
args, path_config, candidate_trackers, simple_tuner
)
print(f"Stored candidate tuning specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling dispatch candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
args, path_config, candidates, candidate_trackers, simple_tuner
)
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
return
Expand All @@ -127,14 +92,14 @@ def main():
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
simple_tuner,
args.simple_num_dispatch_candidates,
)
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
return

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
simple_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.simple_hip_target}",
]
Expand All @@ -143,14 +108,14 @@ def main():
path_config,
top_candidates,
candidate_trackers,
test_tuner,
simple_tuner,
args.simple_model_file,
)
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
return

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
simple_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
Expand All @@ -160,7 +125,7 @@ def main():
path_config,
compiled_model_candidates,
candidate_trackers,
test_tuner,
simple_tuner,
args.simple_num_model_candidates,
)

Expand Down

0 comments on commit fc79ce9

Please sign in to comment.