diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index acc01176a..969979393 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -17,6 +17,7 @@ import math +import random import signal import sys import shutil @@ -816,6 +817,63 @@ def compile( return compiled_candidates +def select_best_benchmark_results( + candidate_results: list[BenchmarkResult], + baseline_results: list[BenchmarkResult], + num_candidates: Optional[int], +) -> list[BenchmarkResult]: + filtered_candidate_results = [r for r in candidate_results if math.isfinite(r.time)] + if len(filtered_candidate_results) == 0: + logging.error("No successful candidate benchmarks.") + return [] + fallback_baseline_time: Optional[float] = None + filtered_baseline_results: list[BenchmarkResult] = [] + for r in baseline_results: + if math.isfinite(r.time): + filtered_baseline_results.append(r) + fallback_baseline_time = r.time + else: + logging.warning(f"Baseline on device {r.device_id} failed.") + if fallback_baseline_time is None: + logging.warning( + f"All baseline benchmarks failed. Baselines will not be used to select top candidates" + ) + baseline_times_by_device = {} + for r in filtered_baseline_results: + baseline_times_by_device[r.device_id] = r.time + + # Select top candidates + def get_speedup(result: BenchmarkResult) -> float: + if result.device_id in baseline_times_by_device: + return result.time / baseline_times_by_device[result.device_id] + assert fallback_baseline_time is not None, "expected fallback_baseline_time" + return result.time / fallback_baseline_time + + num_top_candidates = len(filtered_candidate_results) + if num_candidates is not None: + num_top_candidates = num_candidates + + # Sort by the speedup over baseline on the same device. If a device failed + # the baseline benchmark, then use the fallback baseline. If there is no + # successful baseline, then the best we can do is to sort by the actual + # time. + sorting_key = get_speedup + if fallback_baseline_time is None: + sorting_key = lambda result: result.time + best_results = sorted(filtered_candidate_results, key=sorting_key)[ + :num_top_candidates + ] + logging.info(f"Selected top[{len(best_results)}]:") + + for r in best_results: + if fallback_baseline_time is not None: + speedup = f"{round(get_speedup(r) * 100, 2)}% of baseline" + else: + speedup = "baseline unavailable" + logging.info(f"Candidate {r.candidate_id} time: {r.time} ({speedup})") + return best_results + + def benchmark( args: argparse.Namespace, path_config: PathConfig, @@ -825,6 +883,9 @@ def benchmark( num_candidates: Optional[int] = None, ): logging.debug("benchmark()") + if len(compiled_candidates) == 0: + logging.warning("No candidates to benchmark.") + return [] task_list = [ BenchmarkPack( @@ -836,14 +897,13 @@ def benchmark( if i != 0 ] worker_context_queue = create_worker_context_queue(args.devices) - candidate_results = multiprocess_progress_wrapper( + candidate_results: list[BenchmarkResult] = multiprocess_progress_wrapper( num_worker=len(args.devices), task_list=task_list, function=run_iree_benchmark_module_command, initializer=init_worker_context, initializer_inputs=(worker_context_queue,), ) - candidate_results = [r for r in candidate_results if math.isfinite(r.time)] # Benchmarking baselines on each involved device. worker_context_queue = create_worker_context_queue(args.devices) @@ -854,36 +914,19 @@ def benchmark( candidate_tracker=candidate_trackers[0], ) ] * len(args.devices) - baseline_results = multiprocess_progress_wrapper( + baseline_results: list[BenchmarkResult] = multiprocess_progress_wrapper( num_worker=len(args.devices), task_list=baseline_task_list, function=run_iree_benchmark_module_command, initializer=init_worker_context, initializer_inputs=(worker_context_queue,), ) - baseline_results = [r for r in baseline_results if math.isfinite(r.time)] - assert len(baseline_results) == len( - args.devices - ), "baseline benchmarks should not fail" - baseline_times_by_device = {} - for r in baseline_results: - baseline_times_by_device[r.device_id] = r.time - # Select top candidates - def get_speedup(result: BenchmarkResult) -> float: - return result.time / baseline_times_by_device[result.device_id] - - num_top_candidates = len(candidate_results) - if num_candidates is not None: - num_top_candidates = num_candidates - best_results = sorted(candidate_results, key=get_speedup)[:num_top_candidates] - logging.info(f"Selected top[{len(best_results)}]:") - - for r in best_results: - speedup = round(get_speedup(r) * 100, 2) - logging.info( - f"Candidate {r.candidate_id} time: {r.time} ({speedup}% of baseline)" - ) + best_results: list[BenchmarkResult] = select_best_benchmark_results( + candidate_results=candidate_results, + baseline_results=baseline_results, + num_candidates=num_candidates, + ) top_candidates = [result.candidate_id for result in best_results] return top_candidates diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py index 767a6aff4..cad57a3cd 100644 --- a/tuner/tuner/libtuner_test.py +++ b/tuner/tuner/libtuner_test.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import argparse +import math import pytest import json from subprocess import CompletedProcess @@ -175,5 +176,60 @@ def test_validate_devices_with_invalid_device() -> None: assert expected_call in mock_handle_error.call_args_list +def test_select_best_benchmark_results() -> None: + candidate_results = [ + libtuner.BenchmarkResult(1, 0.5, "hip://0"), + libtuner.BenchmarkResult(2, 0.3, "hip://1"), + libtuner.BenchmarkResult(3, 0.2, "hip://2"), + libtuner.BenchmarkResult(4, 0.1, "hip://3"), + ] + baseline_results = [ + libtuner.BenchmarkResult(0, 1.0, "hip://0"), + libtuner.BenchmarkResult(0, 0.1, "hip://1"), + libtuner.BenchmarkResult(0, 0.1, "hip://2"), + libtuner.BenchmarkResult(0, 0.1, "hip://3"), + ] + best_results: list[ + libtuner.BenchmarkResult + ] = libtuner.select_best_benchmark_results( + candidate_results=candidate_results, + baseline_results=baseline_results, + num_candidates=3, + ) + assert best_results[0].candidate_id == 1 + assert best_results[1].candidate_id == 4 + assert best_results[2].candidate_id == 3 + + baseline_results = [ + libtuner.BenchmarkResult(0, math.inf, "hip://0"), + libtuner.BenchmarkResult(0, 0.1, "hip://1"), + libtuner.BenchmarkResult(0, 0.1, "hip://2"), + libtuner.BenchmarkResult(0, 0.1, "hip://3"), + ] + best_results = libtuner.select_best_benchmark_results( + candidate_results=candidate_results, + baseline_results=baseline_results, + num_candidates=3, + ) + assert best_results[0].candidate_id == 4 + assert best_results[1].candidate_id == 3 + assert best_results[2].candidate_id == 2 + + baseline_results = [ + libtuner.BenchmarkResult(0, math.inf, "hip://0"), + libtuner.BenchmarkResult(0, math.inf, "hip://1"), + libtuner.BenchmarkResult(0, math.inf, "hip://2"), + libtuner.BenchmarkResult(0, math.inf, "hip://3"), + ] + best_results = libtuner.select_best_benchmark_results( + candidate_results=candidate_results, + baseline_results=baseline_results, + num_candidates=3, + ) + assert best_results[0].candidate_id == 4 + assert best_results[1].candidate_id == 3 + assert best_results[2].candidate_id == 2 + + def test_enum_collision(): from iree.compiler.dialects import linalg, vector, iree_gpu, iree_codegen, iree_input # type: ignore