diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 26403c4a9..1fbc616ff 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -801,21 +801,23 @@ def benchmark_baseline( class BaselineResultHandler: def __init__(self) -> None: - # Maps device IDs to a list of baseline run times (in milliseconds). - self.device_baseline_times: dict[str, list[float]] = defaultdict(list) + # Maps device IDs to a list of `BenchmarkResult`. + self.device_baseline_results: dict[str, list[BenchmarkResult]] = defaultdict( + list + ) def add_run(self, results: list[BenchmarkResult]) -> None: for result in results: - self.device_baseline_times[result.device_id].append(result.time) + self.device_baseline_results[result.device_id].append(result) def are_baseline_devices_unique(self, results: list[BenchmarkResult]) -> bool: - return len(results) == len(set(map(lambda r: r.device_id, results))) + return len(results) == len(set(result.device_id for result in results)) def get_valid_time_ms(self, device_id: str) -> list[float]: return [ - time - for time in self.device_baseline_times.get(device_id, []) - if math.isfinite(time) + result.time + for result in self.device_baseline_results.get(device_id, []) + if math.isfinite(result.time) ] def num_successful_runs(self, device_id: str) -> int: @@ -859,7 +861,7 @@ def is_valid(self) -> bool: """ return any( self.get_valid_time_ms(device_id) - for device_id in self.device_baseline_times + for device_id in self.device_baseline_results ) def is_valid_for_device(self, device_id: str) -> bool: @@ -882,9 +884,10 @@ def calculate_speedup( # Calculate the fallback baseline as the average of all valid times across devices valid_baseline_times = [ - time - for device_id in self.device_baseline_times - for time in self.get_valid_time_ms(device_id) + result.time + for device_id in self.device_baseline_results + for result in self.device_baseline_results[device_id] + if math.isfinite(result.time) ] fallback_baseline = sum(valid_baseline_times) / len(valid_baseline_times) diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py index 572dac1e2..b3b1ad575 100644 --- a/tuner/tuner/libtuner_test.py +++ b/tuner/tuner/libtuner_test.py @@ -207,8 +207,13 @@ def test_baseline_result_handler_valid(): assert handler.is_valid_for_device("hip://0") assert not handler.is_valid_for_device("hip://1") - assert handler.device_baseline_times["hip://0"] == [0.5, 0.7] - assert handler.device_baseline_times["hip://1"] == [math.inf] + assert handler.device_baseline_results["hip://0"] == [ + libtuner.BenchmarkResult(0, 0.5, "hip://0"), + libtuner.BenchmarkResult(0, 0.7, "hip://0"), + ] + assert handler.device_baseline_results["hip://1"] == [ + libtuner.BenchmarkResult(0, math.inf, "hip://1"), + ] assert handler.num_successful_runs("hip://0") == 2 assert handler.num_successful_runs("hip://1") == 0