Skip to content

Commit

Permalink
continue to add helper functions and tests, fix mypy notes
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Jan 9, 2025
1 parent f64c84b commit ad10da7
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 44 deletions.
5 changes: 2 additions & 3 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from abc import abstractmethod

from iree.compiler import ir # type: ignore

from iree.compiler.dialects import iree_codegen # type: ignore

from .common import *
Expand Down Expand Up @@ -355,7 +354,7 @@ def main():
prefetch_shared_memory=args.prefetch_shared_memory_options,
no_reduce_shared_memory_bank_conflicts=args.no_reduce_shared_memory_bank_conflicts_options,
)
specs: list[ir.Module] = generate_configs_and_td_specs(
specs = generate_configs_and_td_specs(
mlir_module,
tuner_ctx,
args.limit,
Expand All @@ -369,7 +368,7 @@ def main():
spec_path = spec_dir / f"{candidate_num}_spec.mlir"
spec_dir.mkdir(parents=True, exist_ok=True)
with open(spec_path, "w") as f:
local_scope_spec_str: str = spec.operation.get_asm(use_local_scope=True)
local_scope_spec_str = spec.operation.get_asm(use_local_scope=True)
f.write(local_scope_spec_str)


Expand Down
104 changes: 65 additions & 39 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,9 @@ def validate_baselines_device_ids_match(
def validate_baseline_regression(
first_baseline_by_device: dict[str, float],
second_baseline_by_device: dict[str, float],
) -> bool:
regression_detected = False
) -> list[str]:
regression_device_ids = []

for device_id in first_baseline_by_device:
if device_id not in second_baseline_by_device:
continue
Expand All @@ -262,9 +263,9 @@ def validate_baseline_regression(
f"Baseline time = {first_baseline_time}, Post-baseline time = {second_baseline_time}, "
f"Slower by {percentage_slower:.3f}%"
)
regression_detected = True
regression_device_ids.append(device_id)

return regression_detected
return regression_device_ids


class ExecutionPhases(str, Enum):
Expand Down Expand Up @@ -785,6 +786,31 @@ def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, lis
return collision_detected, unique_indexes


def benchmark_candidates(candidate_indices, devices, tuning_client, candidate_trackers):
"""
Runs the benchmarking for a given list of candidate indices.
"""
worker_context_queue = create_worker_context_queue(devices)

task_list = [
BenchmarkPack(
iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(),
benchmark_timeout=tuning_client.get_benchmark_timeout_s(),
candidate_tracker=candidate_trackers[idx],
)
for idx in candidate_indices
]

# Perform benchmarking.
return multiprocess_progress_wrapper(
num_worker=len(devices),
task_list=task_list,
function=run_iree_benchmark_module_command,
initializer=init_worker_context,
initializer_inputs=(worker_context_queue,),
)


def compile(
args: argparse.Namespace,
path_config: PathConfig,
Expand Down Expand Up @@ -873,12 +899,13 @@ def select_best_benchmark_results(
baseline_results: list[BenchmarkResult],
num_candidates: Optional[int],
) -> list[BenchmarkResult]:
filtered_candidate_results = [r for r in candidate_results if math.isfinite(r.time)]
filtered_candidate_results = validate_benchmark_results(candidate_results)
if len(filtered_candidate_results) == 0:
logging.error("No successful candidate benchmarks.")
return []
fallback_baseline_time: Optional[float] = None
filtered_baseline_results: list[BenchmarkResult] = []
# TODO(Bangtian): use median number instead of last valid baseline result as fallback.
for r in baseline_results:
if math.isfinite(r.time):
filtered_baseline_results.append(r)
Expand All @@ -889,9 +916,10 @@ def select_best_benchmark_results(
logging.warning(
f"All baseline benchmarks failed. Baselines will not be used to select top candidates"
)
baseline_times_by_device = {}
for r in filtered_baseline_results:
baseline_times_by_device[r.device_id] = r.time

baseline_times_by_device = map_baseline_by_device(filtered_baseline_results)
# for r in filtered_baseline_results:
# baseline_times_by_device[r.device_id] = r.time

# Select top candidates
def get_speedup(result: BenchmarkResult) -> float:
Expand Down Expand Up @@ -938,40 +966,38 @@ def benchmark(
logging.warning("No candidates to benchmark.")
return []

task_list = [
BenchmarkPack(
iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(),
benchmark_timeout=tuning_client.get_benchmark_timeout_s(),
candidate_tracker=candidate_trackers[i],
)
for i in compiled_candidates
if i != 0
]
worker_context_queue = create_worker_context_queue(args.devices)
candidate_results: list[BenchmarkResult] = multiprocess_progress_wrapper(
num_worker=len(args.devices),
task_list=task_list,
function=run_iree_benchmark_module_command,
initializer=init_worker_context,
initializer_inputs=(worker_context_queue,),
# Benchmarking baselines on each involved device.
baseline_indices = [0] * len(args.devices)
baseline_results = benchmark_candidates(
candidate_indices=baseline_indices,
devices=args.devices,
tuning_client=tuning_client,
candidate_trackers=candidate_trackers,
)

# Benchmarking baselines on each involved device.
worker_context_queue = create_worker_context_queue(args.devices)
baseline_task_list = [
BenchmarkPack(
iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(),
benchmark_timeout=tuning_client.get_benchmark_timeout_s(),
candidate_tracker=candidate_trackers[0],
)
] * len(args.devices)
baseline_results: list[BenchmarkResult] = multiprocess_progress_wrapper(
num_worker=len(args.devices),
task_list=baseline_task_list,
function=run_iree_benchmark_module_command,
initializer=init_worker_context,
initializer_inputs=(worker_context_queue,),
baseline_times_by_device = map_baseline_by_device(baseline_results)

candidate_indices = [i for i in compiled_candidates if i != 0]
candidate_results = benchmark_candidates(
candidate_indices=candidate_indices,
devices=args.devices,
tuning_client=tuning_client,
candidate_trackers=candidate_trackers,
)

# Benchmarking baselines again to check for performance regressions.
# These may indicate machine instability, overheating, etc.
post_baseline_indices = [0] * len(args.devices)
post_baseline_results = benchmark_candidates(
candidate_indices=post_baseline_indices,
devices=args.devices,
tuning_client=tuning_client,
candidate_trackers=candidate_trackers,
)
post_baseline_times_by_device = map_baseline_by_device(post_baseline_results)
assert (
baseline_times_by_device.keys() == post_baseline_times_by_device.keys()
), "Device ID mismatch between baseline runs."

best_results: list[BenchmarkResult] = select_best_benchmark_results(
candidate_results=candidate_results,
Expand Down
26 changes: 24 additions & 2 deletions tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

import argparse
import math
import pytest
import json
from subprocess import CompletedProcess
from unittest.mock import call, patch, MagicMock
from . import libtuner
Expand Down Expand Up @@ -270,3 +268,27 @@ def test_validate_baselines_device_id_match():
first_baseline, second_baseline
)
assert result is True


def test_validate_baseline_regression():
first_baseline = {"hip://0": 1000.0, "hip://1": 2000.0}
second_baseline = {"hip://0": 1100.0, "hip://1": 1900.0}
regression_devices = libtuner.validate_baseline_regression(
first_baseline, second_baseline
)
assert regression_devices == ["hip://0"]

first_baseline = {"hip://0": 1000.0, "hip://1": 2000.0}
second_baseline = {"hip://0": 1000.0, "hip://1": 2000.0}

regression_devices = libtuner.validate_baseline_regression(
first_baseline, second_baseline
)
assert regression_devices == []

first_baseline = {"hip://0": 1000.0, "hip://1": 2000.0}
second_baseline = {"hip://0": 1100.0}
regression_devices = libtuner.validate_baseline_regression(
first_baseline, second_baseline
)
assert regression_devices == ["hip://0"]

0 comments on commit ad10da7

Please sign in to comment.