Skip to content

Commit

Permalink
[tuner] Filter out non finite benchmark times (#799)
Browse files Browse the repository at this point in the history
This PR fixes a bug where math.inf benchmark times can be selected as
the top candidates. Any non finite times are now filtered out before
selecting top candidates.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Jan 9, 2025
1 parent 500df58 commit 35ad7d0
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 22 deletions.
89 changes: 67 additions & 22 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,9 +530,7 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack):
)

times = []
logging.debug(f"candidate {candidate_id} benchmark_results: {benchmark_results}")
for benchmark_result in benchmark_results:
logging.debug(f"candidate {candidate_id} benchmark_result: {benchmark_result}")
benchmark_name = benchmark_result.benchmark_name
# With multiple benchmark results, there will be `real_time_mean`, but
# not with single iteration benchmark results, so ignore the mean time
Expand Down Expand Up @@ -818,6 +816,63 @@ def compile(
return compiled_candidates


def select_best_benchmark_results(
candidate_results: list[BenchmarkResult],
baseline_results: list[BenchmarkResult],
num_candidates: Optional[int],
) -> list[BenchmarkResult]:
filtered_candidate_results = [r for r in candidate_results if math.isfinite(r.time)]
if len(filtered_candidate_results) == 0:
logging.error("No successful candidate benchmarks.")
return []
fallback_baseline_time: Optional[float] = None
filtered_baseline_results: list[BenchmarkResult] = []
for r in baseline_results:
if math.isfinite(r.time):
filtered_baseline_results.append(r)
fallback_baseline_time = r.time
else:
logging.warning(f"Baseline on device {r.device_id} failed.")
if fallback_baseline_time is None:
logging.warning(
f"All baseline benchmarks failed. Baselines will not be used to select top candidates"
)
baseline_times_by_device = {}
for r in filtered_baseline_results:
baseline_times_by_device[r.device_id] = r.time

# Select top candidates
def get_speedup(result: BenchmarkResult) -> float:
if result.device_id in baseline_times_by_device:
return result.time / baseline_times_by_device[result.device_id]
assert fallback_baseline_time is not None, "expected fallback_baseline_time"
return result.time / fallback_baseline_time

num_top_candidates = len(filtered_candidate_results)
if num_candidates is not None:
num_top_candidates = num_candidates

# Sort by the speedup over baseline on the same device. If a device failed
# the baseline benchmark, then use the fallback baseline. If there is no
# successful baseline, then the best we can do is to sort by the actual
# time.
sorting_key = get_speedup
if fallback_baseline_time is None:
sorting_key = lambda result: result.time
best_results = sorted(filtered_candidate_results, key=sorting_key)[
:num_top_candidates
]
logging.info(f"Selected top[{len(best_results)}]:")

for r in best_results:
if fallback_baseline_time is not None:
speedup = f"{round(get_speedup(r) * 100, 2)}% of baseline"
else:
speedup = "baseline unavailable"
logging.info(f"Candidate {r.candidate_id} time: {r.time} ({speedup})")
return best_results


def benchmark(
args: argparse.Namespace,
path_config: PathConfig,
Expand All @@ -827,6 +882,9 @@ def benchmark(
num_candidates: Optional[int] = None,
):
logging.debug("benchmark()")
if len(compiled_candidates) == 0:
logging.warning("No candidates to benchmark.")
return []

task_list = [
BenchmarkPack(
Expand All @@ -838,7 +896,7 @@ def benchmark(
if i != 0
]
worker_context_queue = create_worker_context_queue(args.devices)
candidate_results = multiprocess_progress_wrapper(
candidate_results: list[BenchmarkResult] = multiprocess_progress_wrapper(
num_worker=len(args.devices),
task_list=task_list,
function=run_iree_benchmark_module_command,
Expand All @@ -855,32 +913,19 @@ def benchmark(
candidate_tracker=candidate_trackers[0],
)
] * len(args.devices)
baseline_results = multiprocess_progress_wrapper(
baseline_results: list[BenchmarkResult] = multiprocess_progress_wrapper(
num_worker=len(args.devices),
task_list=baseline_task_list,
function=run_iree_benchmark_module_command,
initializer=init_worker_context,
initializer_inputs=(worker_context_queue,),
)
baseline_times_by_device = {}
for r in baseline_results:
baseline_times_by_device[r.device_id] = r.time

# Select top candidates
def get_speedup(result: BenchmarkResult) -> float:
return result.time / baseline_times_by_device[result.device_id]

num_top_candidates = len(candidate_results)
if num_candidates is not None:
num_top_candidates = num_candidates
best_results = sorted(candidate_results, key=get_speedup)[:num_top_candidates]
logging.info(f"Selected top[{len(best_results)}]:")

for r in best_results:
speedup = round(get_speedup(r) * 100, 2)
logging.info(
f"Candidate {r.candidate_id} time: {r.time} ({speedup}% of baseline)"
)
best_results: list[BenchmarkResult] = select_best_benchmark_results(
candidate_results=candidate_results,
baseline_results=baseline_results,
num_candidates=num_candidates,
)

top_candidates = [result.candidate_id for result in best_results]
return top_candidates
56 changes: 56 additions & 0 deletions tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import math
import pytest
import json
from subprocess import CompletedProcess
Expand Down Expand Up @@ -175,5 +176,60 @@ def test_validate_devices_with_invalid_device() -> None:
assert expected_call in mock_handle_error.call_args_list


def test_select_best_benchmark_results() -> None:
candidate_results = [
libtuner.BenchmarkResult(1, 0.5, "hip://0"),
libtuner.BenchmarkResult(2, 0.3, "hip://1"),
libtuner.BenchmarkResult(3, 0.2, "hip://2"),
libtuner.BenchmarkResult(4, 0.1, "hip://3"),
]
baseline_results = [
libtuner.BenchmarkResult(0, 1.0, "hip://0"),
libtuner.BenchmarkResult(0, 0.1, "hip://1"),
libtuner.BenchmarkResult(0, 0.1, "hip://2"),
libtuner.BenchmarkResult(0, 0.1, "hip://3"),
]
best_results: list[
libtuner.BenchmarkResult
] = libtuner.select_best_benchmark_results(
candidate_results=candidate_results,
baseline_results=baseline_results,
num_candidates=3,
)
assert best_results[0].candidate_id == 1
assert best_results[1].candidate_id == 4
assert best_results[2].candidate_id == 3

baseline_results = [
libtuner.BenchmarkResult(0, math.inf, "hip://0"),
libtuner.BenchmarkResult(0, 0.1, "hip://1"),
libtuner.BenchmarkResult(0, 0.1, "hip://2"),
libtuner.BenchmarkResult(0, 0.1, "hip://3"),
]
best_results = libtuner.select_best_benchmark_results(
candidate_results=candidate_results,
baseline_results=baseline_results,
num_candidates=3,
)
assert best_results[0].candidate_id == 4
assert best_results[1].candidate_id == 3
assert best_results[2].candidate_id == 2

baseline_results = [
libtuner.BenchmarkResult(0, math.inf, "hip://0"),
libtuner.BenchmarkResult(0, math.inf, "hip://1"),
libtuner.BenchmarkResult(0, math.inf, "hip://2"),
libtuner.BenchmarkResult(0, math.inf, "hip://3"),
]
best_results = libtuner.select_best_benchmark_results(
candidate_results=candidate_results,
baseline_results=baseline_results,
num_candidates=3,
)
assert best_results[0].candidate_id == 4
assert best_results[1].candidate_id == 3
assert best_results[2].candidate_id == 2


def test_enum_collision():
from iree.compiler.dialects import linalg, vector, iree_gpu, iree_codegen, iree_input # type: ignore

0 comments on commit 35ad7d0

Please sign in to comment.