Skip to content

Commit

Permalink
[Tuner] Fix context management (#770)
Browse files Browse the repository at this point in the history
This PR is about addressing the MLIR context management issue in the
tuner detailed in #764.

Although this is a work in progress, I am sending it to gather feedback
and ensure I am heading in the right direction.

---------

Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu authored Jan 8, 2025
1 parent 64dfcb2 commit ed7906f
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 98 deletions.
130 changes: 66 additions & 64 deletions tuner/examples/simple/simple_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import argparse
from pathlib import Path
from tuner import libtuner
from tuner.common import *


class TestTuner(libtuner.TuningClient):
def __init__(self):
super().__init__()
def __init__(self, tuner_context: libtuner.TunerContext):
super().__init__(tuner_context)
self.compile_flags = ["--compile-from=executable-sources"]
self.benchmark_flags = ["--benchmark_repetitions=3", "--input=1"]

Expand Down Expand Up @@ -104,68 +105,69 @@ def main():
print("Validation successful!\n")

print("Generating candidate tuning specs...")
test_tuner = TestTuner()
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
)
print(f"Stored candidate tuning specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling dispatch candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
)
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
return

print("Benchmarking compiled dispatch candidates...")
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
args.simple_num_dispatch_candidates,
)
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
return

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.simple_hip_target}",
]
compiled_model_candidates = libtuner.compile(
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.simple_model_file,
)
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
return

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
]
top_model_candidates = libtuner.benchmark(
args,
path_config,
compiled_model_candidates,
candidate_trackers,
test_tuner,
args.simple_num_model_candidates,
)

print(f"Top model candidates: {top_model_candidates}")

print("Check the detailed execution logs in:")
print(path_config.run_log.resolve())
with TunerContext() as tuner_context:
test_tuner = TestTuner(tuner_context)
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
)
print(f"Stored candidate tuning specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling dispatch candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
)
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
return

print("Benchmarking compiled dispatch candidates...")
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
args.simple_num_dispatch_candidates,
)
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
return

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.simple_hip_target}",
]
compiled_model_candidates = libtuner.compile(
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.simple_model_file,
)
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
return

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
]
top_model_candidates = libtuner.benchmark(
args,
path_config,
compiled_model_candidates,
candidate_trackers,
test_tuner,
args.simple_num_model_candidates,
)

print(f"Top model candidates: {top_model_candidates}")

print("Check the detailed execution logs in:")
print(path_config.run_log.resolve())

for candidate in candidate_trackers:
libtuner.logging.debug(candidate)
6 changes: 3 additions & 3 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
mock_logger = MagicMock(spec=Logger)
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
Expand Down
23 changes: 17 additions & 6 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import re
import logging
from dataclasses import astuple, dataclass, field
from enum import Enum
from types import TracebackType
from typing import Optional
from typing import Any

from iree.compiler import ir # type: ignore

from iree.compiler.dialects import iree_gpu # type: ignore
from iree.compiler.dialects import iree_codegen # type: ignore


class CommonTypes:
Expand All @@ -38,10 +37,22 @@ def getI64(self, value: int) -> ir.IntegerAttr:


class TunerContext:
def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger):
self.mlir_ctx: ir.Context = mlir_ctx
self.logger: logging.Logger = logger
self.type: CommonTypes = CommonTypes(mlir_ctx)
def __init__(self, logger: Optional[logging.Logger] = None):
self.mlir_ctx: ir.Context = ir.Context()
self.logger: logging.Logger = logger or logging.getLogger("tune")
self.type: CommonTypes = CommonTypes(self.mlir_ctx)

def __enter__(self) -> "TunerContext":
self.mlir_ctx.__enter__()
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool:
return self.mlir_ctx.__exit__(exc_type, exc_value, traceback)


class DispatchKind(Enum):
Expand Down
6 changes: 3 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
mock_logger = MagicMock(spec=Logger)
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


@pytest.fixture
Expand Down
6 changes: 3 additions & 3 deletions tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
mock_logger = MagicMock(spec=Logger)
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
Expand Down
6 changes: 3 additions & 3 deletions tuner/tuner/dispatch_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

with ir.Context() as ctx:
logger: Logger = MagicMock(spec=Logger)
yield common.TunerContext(ctx, logger)
mock_logger = MagicMock(spec=Logger)
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx


CONTRACTION_TEMPLATE = r"""
Expand Down
26 changes: 11 additions & 15 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@

import math
import signal
import subprocess
import sys
import shutil
import logging
import argparse
from datetime import datetime
from enum import Enum
from pathlib import Path
import time
import multiprocessing
import queue
from tqdm import tqdm
Expand All @@ -37,6 +35,7 @@
import iree.runtime as ireert # type: ignore
import iree.compiler as ireec # type: ignore
from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_codegen # type: ignore
from . import candidate_gen
from . import dispatch_parser
from .op_matchers import *
Expand Down Expand Up @@ -103,10 +102,8 @@ def get_candidate_vmfb_filename(self, candidate_id: int) -> str:


class TuningClient(ABC):
def __init__(self):
mlir_ctx = ir.Context()
logger = logging.getLogger("tune")
self.tuner_context = TunerContext(mlir_ctx, logger)
def __init__(self, tuner_context: TunerContext):
self.tuner_context = tuner_context

@abstractmethod
def get_iree_compile_flags(self) -> list[str]:
Expand Down Expand Up @@ -644,15 +641,14 @@ def generate_candidate_specs(
# source mlir.
mlir_text = candidate_gen.strip_compilation_info(path_config.template_mlir)
mlir_module = dispatch_parser.parse_mlir(mlir_text, tuning_client.tuner_context)
with tuning_client.tuner_context.mlir_ctx:
logging.debug("Captured messages from candidate_gen.py:")
config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs(
input_module=mlir_module,
tuner_context=tuning_client.tuner_context,
limit=args.num_candidates,
num_subgroups=args.num_subgroups,
codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline),
)
logging.debug("Captured messages from candidate_gen.py:")
config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs(
input_module=mlir_module,
tuner_context=tuning_client.tuner_context,
limit=args.num_candidates,
num_subgroups=args.num_subgroups,
codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline),
)
logging.debug("candidate_gen.py ends")
handle_error(
condition=(len(config_specs) <= 1), msg="Failed to generate any candidates"
Expand Down
2 changes: 1 addition & 1 deletion tuner/tuner/op_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_map_result_dim_positions(map: ir.AffineMap):


class ContractionOpInterfaceMatcher(GenericOpMatcher):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.contraction_dimensions: Optional[ContractionDimensions] = None
self.lhs_dims: Optional[list[int]] = None
Expand Down

0 comments on commit ed7906f

Please sign in to comment.