Skip to content

Commit

Permalink
address reviewer comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Jan 7, 2025
1 parent 348e047 commit 6bb9652
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 89 deletions.
116 changes: 57 additions & 59 deletions tuner/examples/test/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,65 +110,63 @@ def main():
print("Validation successful!\n")

print("Generating candidates...")
mlir_ctx = ir.Context()
logger = logging.getLogger("tune")
tuner_context = TunerContext(mlir_ctx, logger)
test_tuner = TestTuner(tuner_context)
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
)
print(f"Stored candidate specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
)

print("Benchmarking compiled candidates...")
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
args.test_num_dispatch_candidates,
)

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.test_hip_target}",
]
compiled_model_candidates = libtuner.compile(
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.test_model_file,
)

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
]
top_model_candidates = libtuner.benchmark(
args,
path_config,
compiled_model_candidates,
candidate_trackers,
test_tuner,
args.test_num_model_candidates,
)

print(f"Top model candidates: {top_model_candidates}")

print("Check the detailed execution logs in:")
print(path_config.run_log.resolve())
with TunerContext() as tuner_context:
test_tuner = TestTuner(tuner_context)
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
)
print(f"Stored candidate specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
)

print("Benchmarking compiled candidates...")
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
args.test_num_dispatch_candidates,
)

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.test_hip_target}",
]
compiled_model_candidates = libtuner.compile(
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.test_model_file,
)

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
]
top_model_candidates = libtuner.benchmark(
args,
path_config,
compiled_model_candidates,
candidate_trackers,
test_tuner,
args.test_num_model_candidates,
)

print(f"Top model candidates: {top_model_candidates}")

print("Check the detailed execution logs in:")
print(path_config.run_log.resolve())

for candidate in candidate_trackers:
libtuner.logging.debug(candidate)
9 changes: 6 additions & 3 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
# Mock the logger
mock_logger = MagicMock(spec=Logger)

# Use TunerContext with the mocked logger
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
Expand Down
19 changes: 11 additions & 8 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import re
from __future__ import annotations
import logging
from dataclasses import astuple, dataclass, field
from enum import Enum
from typing import Optional
from typing import Any
from typing_extensions import Literal

from iree.compiler import ir # type: ignore

Expand Down Expand Up @@ -38,16 +39,18 @@ def getI64(self, value: int) -> ir.IntegerAttr:


class TunerContext:
def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger):
self.mlir_ctx: ir.Context = mlir_ctx
self.logger: logging.Logger = logger
self.type: CommonTypes = CommonTypes(mlir_ctx)

def __enter__(self):
def __init__(self, logger: Optional[logging.Logger] = None):
self.mlir_ctx: ir.Context = ir.Context()
self.logger: logging.Logger = logger or logging.getLogger(
"tune"
) # Default to "tune" logger
self.type: CommonTypes = CommonTypes(self.mlir_ctx)

def __enter__(self) -> TunerContext:
self.mlir_ctx.__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]:
self.mlir_ctx.__exit__(exc_type, exc_value, traceback)
return False

Expand Down
9 changes: 6 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
# Mock the logger
mock_logger = MagicMock(spec=Logger)

# Use TunerContext with the mocked logger
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


@pytest.fixture
Expand Down
9 changes: 6 additions & 3 deletions tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
# Mock the logger
mock_logger = MagicMock(spec=Logger)

# Use TunerContext with the mocked logger
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
Expand Down
9 changes: 6 additions & 3 deletions tuner/tuner/dispatch_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
# Mock the logger
mock_logger = MagicMock(spec=Logger)

# Use TunerContext with the mocked logger
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


CONTRACTION_TEMPLATE = r"""
Expand Down
10 changes: 0 additions & 10 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,6 @@ class TuningClient(ABC):
def __init__(self, tuner_context: TunerContext):
self.tuner_context = tuner_context

def __enter__(self):
# Enter the context of TunerContext
self.tuner_context.__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback):
# Exit the context of TunerContext
self.tuner_context.__exit__(exc_type, exc_value, traceback)
return False

@abstractmethod
def get_iree_compile_flags(self) -> list[str]:
pass
Expand Down

0 comments on commit 6bb9652

Please sign in to comment.