Skip to content

Commit

Permalink
update the handler class
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Jan 17, 2025
1 parent cfcb32a commit 2b6a2be
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
25 changes: 14 additions & 11 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,21 +801,23 @@ def benchmark_baseline(

class BaselineResultHandler:
def __init__(self) -> None:
# Maps device IDs to a list of baseline run times (in milliseconds).
self.device_baseline_times: dict[str, list[float]] = defaultdict(list)
# Maps device IDs to a list of `BenchmarkResult`.
self.device_baseline_results: dict[str, list[BenchmarkResult]] = defaultdict(
list
)

def add_run(self, results: list[BenchmarkResult]) -> None:
for result in results:
self.device_baseline_times[result.device_id].append(result.time)
self.device_baseline_results[result.device_id].append(result)

def are_baseline_devices_unique(self, results: list[BenchmarkResult]) -> bool:
return len(results) == len(set(map(lambda r: r.device_id, results)))
return len(results) == len(set(result.device_id for result in results))

def get_valid_time_ms(self, device_id: str) -> list[float]:
return [
time
for time in self.device_baseline_times.get(device_id, [])
if math.isfinite(time)
result.time
for result in self.device_baseline_results.get(device_id, [])
if math.isfinite(result.time)
]

def num_successful_runs(self, device_id: str) -> int:
Expand Down Expand Up @@ -859,7 +861,7 @@ def is_valid(self) -> bool:
"""
return any(
self.get_valid_time_ms(device_id)
for device_id in self.device_baseline_times
for device_id in self.device_baseline_results
)

def is_valid_for_device(self, device_id: str) -> bool:
Expand All @@ -882,9 +884,10 @@ def calculate_speedup(

# Calculate the fallback baseline as the average of all valid times across devices
valid_baseline_times = [
time
for device_id in self.device_baseline_times
for time in self.get_valid_time_ms(device_id)
result.time
for device_id in self.device_baseline_results
for result in self.device_baseline_results[device_id]
if math.isfinite(result.time)
]

fallback_baseline = sum(valid_baseline_times) / len(valid_baseline_times)
Expand Down
9 changes: 7 additions & 2 deletions tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,13 @@ def test_baseline_result_handler_valid():
assert handler.is_valid_for_device("hip://0")
assert not handler.is_valid_for_device("hip://1")

assert handler.device_baseline_times["hip://0"] == [0.5, 0.7]
assert handler.device_baseline_times["hip://1"] == [math.inf]
assert handler.device_baseline_results["hip://0"] == [
libtuner.BenchmarkResult(0, 0.5, "hip://0"),
libtuner.BenchmarkResult(0, 0.7, "hip://0"),
]
assert handler.device_baseline_results["hip://1"] == [
libtuner.BenchmarkResult(0, math.inf, "hip://1"),
]

assert handler.num_successful_runs("hip://0") == 2
assert handler.num_successful_runs("hip://1") == 0
Expand Down

0 comments on commit 2b6a2be

Please sign in to comment.