handle error cases
Signed-off-by: Max Dawkins <[email protected]>
Max191 committed Jan 9, 2025
1 parent f2ad14e commit e9e16a3
Showing 2 changed files with 124 additions and 25 deletions.
93 changes: 68 additions & 25 deletions tuner/tuner/libtuner.py
@@ -17,6 +17,7 @@


import math
import random
import signal
import sys
import shutil
@@ -816,6 +817,63 @@ def compile(
    return compiled_candidates


def select_best_benchmark_results(
    candidate_results: list[BenchmarkResult],
    baseline_results: list[BenchmarkResult],
    num_candidates: Optional[int],
) -> list[BenchmarkResult]:
    filtered_candidate_results = [r for r in candidate_results if math.isfinite(r.time)]
    if len(filtered_candidate_results) == 0:
        logging.error("No successful candidate benchmarks.")
        return []
    fallback_baseline_time: Optional[float] = None
    filtered_baseline_results: list[BenchmarkResult] = []
    for r in baseline_results:
        if math.isfinite(r.time):
            filtered_baseline_results.append(r)
            fallback_baseline_time = r.time
        else:
            logging.warning(f"Baseline on device {r.device_id} failed.")
    if fallback_baseline_time is None:
        logging.warning(
            f"All baseline benchmarks failed. Baselines will not be used to select top candidates"
        )
    baseline_times_by_device = {}
    for r in filtered_baseline_results:
        baseline_times_by_device[r.device_id] = r.time

    # Select top candidates
    def get_speedup(result: BenchmarkResult) -> float:
        if result.device_id in baseline_times_by_device:
            return result.time / baseline_times_by_device[result.device_id]
        assert fallback_baseline_time is not None, "expected fallback_baseline_time"
        return result.time / fallback_baseline_time

    num_top_candidates = len(filtered_candidate_results)
    if num_candidates is not None:
        num_top_candidates = num_candidates

    # Sort by the speedup over baseline on the same device. If a device failed
    # the baseline benchmark, then use the fallback baseline. If there is no
    # successful baseline, then the best we can do is to sort by the actual
    # time.
    sorting_key = get_speedup
    if fallback_baseline_time is None:
        sorting_key = lambda result: result.time
    best_results = sorted(filtered_candidate_results, key=sorting_key)[
        :num_top_candidates
    ]
    logging.info(f"Selected top[{len(best_results)}]:")

    for r in best_results:
        if fallback_baseline_time is not None:
            speedup = f"{round(get_speedup(r) * 100, 2)}% of baseline"
        else:
            speedup = "baseline unavailable"
        logging.info(f"Candidate {r.candidate_id} time: {r.time} ({speedup})")
    return best_results


def benchmark(
    args: argparse.Namespace,
    path_config: PathConfig,
@@ -825,6 +883,9 @@ def benchmark(
    num_candidates: Optional[int] = None,
):
    logging.debug("benchmark()")
    if len(compiled_candidates) == 0:
        logging.warning("No candidates to benchmark.")
        return []

    task_list = [
        BenchmarkPack(
@@ -836,14 +897,13 @@
        if i != 0
    ]
    worker_context_queue = create_worker_context_queue(args.devices)
    candidate_results = multiprocess_progress_wrapper(
    candidate_results: list[BenchmarkResult] = multiprocess_progress_wrapper(
        num_worker=len(args.devices),
        task_list=task_list,
        function=run_iree_benchmark_module_command,
        initializer=init_worker_context,
        initializer_inputs=(worker_context_queue,),
    )
    candidate_results = [r for r in candidate_results if math.isfinite(r.time)]

    # Benchmarking baselines on each involved device.
    worker_context_queue = create_worker_context_queue(args.devices)
@@ -854,36 +914,19 @@
            candidate_tracker=candidate_trackers[0],
        )
    ] * len(args.devices)
    baseline_results = multiprocess_progress_wrapper(
    baseline_results: list[BenchmarkResult] = multiprocess_progress_wrapper(
        num_worker=len(args.devices),
        task_list=baseline_task_list,
        function=run_iree_benchmark_module_command,
        initializer=init_worker_context,
        initializer_inputs=(worker_context_queue,),
    )
    baseline_results = [r for r in baseline_results if math.isfinite(r.time)]
    assert len(baseline_results) == len(
        args.devices
    ), "baseline benchmarks should not fail"
    baseline_times_by_device = {}
    for r in baseline_results:
        baseline_times_by_device[r.device_id] = r.time

    # Select top candidates
    def get_speedup(result: BenchmarkResult) -> float:
        return result.time / baseline_times_by_device[result.device_id]

    num_top_candidates = len(candidate_results)
    if num_candidates is not None:
        num_top_candidates = num_candidates
    best_results = sorted(candidate_results, key=get_speedup)[:num_top_candidates]
    logging.info(f"Selected top[{len(best_results)}]:")

    for r in best_results:
        speedup = round(get_speedup(r) * 100, 2)
        logging.info(
            f"Candidate {r.candidate_id} time: {r.time} ({speedup}% of baseline)"
        )
    best_results: list[BenchmarkResult] = select_best_benchmark_results(
        candidate_results=candidate_results,
        baseline_results=baseline_results,
        num_candidates=num_candidates,
    )

    top_candidates = [result.candidate_id for result in best_results]
    return top_candidates
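
For readers without the full libtuner.py in view: the new helper and the tests below construct BenchmarkResult positionally as (candidate_id, time, device_id) and treat non-finite times as failed runs. A minimal sketch of that assumed shape and of the math.isfinite screening it relies on (illustrative only; the actual class in libtuner.py may carry more fields):

import math
from dataclasses import dataclass

@dataclass
class BenchmarkResult:
    # Field order inferred from the positional constructor calls in the
    # tests below; the real definition may differ.
    candidate_id: int
    time: float
    device_id: str

# Failed runs are assumed to surface as non-finite times (math.inf in the
# tests; math.nan would be screened the same way).
results = [BenchmarkResult(1, 0.5, "hip://0"), BenchmarkResult(2, math.inf, "hip://1")]
print([r.candidate_id for r in results if math.isfinite(r.time)])  # prints [1]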
56 changes: 56 additions & 0 deletions tuner/tuner/libtuner_test.py
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import math
import pytest
import json
from subprocess import CompletedProcess
@@ -175,5 +176,60 @@ def test_validate_devices_with_invalid_device() -> None:
    assert expected_call in mock_handle_error.call_args_list


def test_select_best_benchmark_results() -> None:
    candidate_results = [
        libtuner.BenchmarkResult(1, 0.5, "hip://0"),
        libtuner.BenchmarkResult(2, 0.3, "hip://1"),
        libtuner.BenchmarkResult(3, 0.2, "hip://2"),
        libtuner.BenchmarkResult(4, 0.1, "hip://3"),
    ]
    baseline_results = [
        libtuner.BenchmarkResult(0, 1.0, "hip://0"),
        libtuner.BenchmarkResult(0, 0.1, "hip://1"),
        libtuner.BenchmarkResult(0, 0.1, "hip://2"),
        libtuner.BenchmarkResult(0, 0.1, "hip://3"),
    ]
    best_results: list[
        libtuner.BenchmarkResult
    ] = libtuner.select_best_benchmark_results(
        candidate_results=candidate_results,
        baseline_results=baseline_results,
        num_candidates=3,
    )
    assert best_results[0].candidate_id == 1
    assert best_results[1].candidate_id == 4
    assert best_results[2].candidate_id == 3

    baseline_results = [
        libtuner.BenchmarkResult(0, math.inf, "hip://0"),
        libtuner.BenchmarkResult(0, 0.1, "hip://1"),
        libtuner.BenchmarkResult(0, 0.1, "hip://2"),
        libtuner.BenchmarkResult(0, 0.1, "hip://3"),
    ]
    best_results = libtuner.select_best_benchmark_results(
        candidate_results=candidate_results,
        baseline_results=baseline_results,
        num_candidates=3,
    )
    assert best_results[0].candidate_id == 4
    assert best_results[1].candidate_id == 3
    assert best_results[2].candidate_id == 2

    baseline_results = [
        libtuner.BenchmarkResult(0, math.inf, "hip://0"),
        libtuner.BenchmarkResult(0, math.inf, "hip://1"),
        libtuner.BenchmarkResult(0, math.inf, "hip://2"),
        libtuner.BenchmarkResult(0, math.inf, "hip://3"),
    ]
    best_results = libtuner.select_best_benchmark_results(
        candidate_results=candidate_results,
        baseline_results=baseline_results,
        num_candidates=3,
    )
    assert best_results[0].candidate_id == 4
    assert best_results[1].candidate_id == 3
    assert best_results[2].candidate_id == 2
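
The last case above exercises the raw-time fallback: when no baseline is finite, selection degenerates to sorting candidates by result.time directly. A standalone illustration of that ordering, in plain Python with the same hypothetical numbers as the test:

times = {1: 0.5, 2: 0.3, 3: 0.2, 4: 0.1}  # candidate_id -> measured time
top3 = sorted(times, key=times.get)[:3]
print(top3)  # [4, 3, 2], the order asserted above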


def test_enum_collision():
    from iree.compiler.dialects import linalg, vector, iree_gpu, iree_codegen, iree_input  # type: ignore
