From ed7906f0517425580031bcfdad8b09814a1e5ba7 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 8 Jan 2025 00:41:14 -0500 Subject: [PATCH] [Tuner] Fix context management (#770) This PR is about addressing the MLIR context management issue in the tuner detailed in https://github.com/nod-ai/shark-ai/issues/764. Although this is a work in progress, I am sending it to gather feedback and ensure I am heading in the right direction. --------- Signed-off-by: Bangtian Liu --- tuner/examples/simple/simple_tuner.py | 130 ++++++++++++----------- tuner/tuner/candidate_gen_test.py | 6 +- tuner/tuner/common.py | 23 ++-- tuner/tuner/common_test.py | 6 +- tuner/tuner/dispatch_constraints_test.py | 6 +- tuner/tuner/dispatch_parser_test.py | 6 +- tuner/tuner/libtuner.py | 26 ++--- tuner/tuner/op_matchers.py | 2 +- 8 files changed, 107 insertions(+), 98 deletions(-) diff --git a/tuner/examples/simple/simple_tuner.py b/tuner/examples/simple/simple_tuner.py index d78ec5b53..63421fdfe 100644 --- a/tuner/examples/simple/simple_tuner.py +++ b/tuner/examples/simple/simple_tuner.py @@ -7,11 +7,12 @@ import argparse from pathlib import Path from tuner import libtuner +from tuner.common import * class TestTuner(libtuner.TuningClient): - def __init__(self): - super().__init__() + def __init__(self, tuner_context: libtuner.TunerContext): + super().__init__(tuner_context) self.compile_flags = ["--compile-from=executable-sources"] self.benchmark_flags = ["--benchmark_repetitions=3", "--input=1"] @@ -104,68 +105,69 @@ def main(): print("Validation successful!\n") print("Generating candidate tuning specs...") - test_tuner = TestTuner() - candidates = libtuner.generate_candidate_specs( - args, path_config, candidate_trackers, test_tuner - ) - print(f"Stored candidate tuning specs in {path_config.specs_dir}\n") - if stop_after_phase == libtuner.ExecutionPhases.generate_candidates: - return - - print("Compiling dispatch candidates...") - compiled_candidates = libtuner.compile( - args, path_config, candidates, candidate_trackers, test_tuner - ) - if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches: - return - - print("Benchmarking compiled dispatch candidates...") - top_candidates = libtuner.benchmark( - args, - path_config, - compiled_candidates, - candidate_trackers, - test_tuner, - args.simple_num_dispatch_candidates, - ) - if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches: - return - - print("Compiling models with top candidates...") - test_tuner.compile_flags = [ - "--iree-hal-target-backends=rocm", - f"--iree-hip-target={args.simple_hip_target}", - ] - compiled_model_candidates = libtuner.compile( - args, - path_config, - top_candidates, - candidate_trackers, - test_tuner, - args.simple_model_file, - ) - if stop_after_phase == libtuner.ExecutionPhases.compile_models: - return - - print("Benchmarking compiled model candidates...") - test_tuner.benchmark_flags = [ - "--benchmark_repetitions=3", - "--input=2048x2048xf16", - "--input=2048x2048xf16", - ] - top_model_candidates = libtuner.benchmark( - args, - path_config, - compiled_model_candidates, - candidate_trackers, - test_tuner, - args.simple_num_model_candidates, - ) - - print(f"Top model candidates: {top_model_candidates}") - - print("Check the detailed execution logs in:") - print(path_config.run_log.resolve()) + with TunerContext() as tuner_context: + test_tuner = TestTuner(tuner_context) + candidates = libtuner.generate_candidate_specs( + args, path_config, candidate_trackers, test_tuner + ) + print(f"Stored candidate tuning specs in {path_config.specs_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.generate_candidates: + return + + print("Compiling dispatch candidates...") + compiled_candidates = libtuner.compile( + args, path_config, candidates, candidate_trackers, test_tuner + ) + if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches: + return + + print("Benchmarking compiled dispatch candidates...") + top_candidates = libtuner.benchmark( + args, + path_config, + compiled_candidates, + candidate_trackers, + test_tuner, + args.simple_num_dispatch_candidates, + ) + if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches: + return + + print("Compiling models with top candidates...") + test_tuner.compile_flags = [ + "--iree-hal-target-backends=rocm", + f"--iree-hip-target={args.simple_hip_target}", + ] + compiled_model_candidates = libtuner.compile( + args, + path_config, + top_candidates, + candidate_trackers, + test_tuner, + args.simple_model_file, + ) + if stop_after_phase == libtuner.ExecutionPhases.compile_models: + return + + print("Benchmarking compiled model candidates...") + test_tuner.benchmark_flags = [ + "--benchmark_repetitions=3", + "--input=2048x2048xf16", + "--input=2048x2048xf16", + ] + top_model_candidates = libtuner.benchmark( + args, + path_config, + compiled_model_candidates, + candidate_trackers, + test_tuner, + args.simple_num_model_candidates, + ) + + print(f"Top model candidates: {top_model_candidates}") + + print("Check the detailed execution logs in:") + print(path_config.run_log.resolve()) for candidate in candidate_trackers: libtuner.logging.debug(candidate) diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 8b0ca58d3..6a62e90e4 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -27,9 +27,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]: from logging import Logger from unittest.mock import MagicMock - with ir.Context() as ctx: - logger: Logger = MagicMock(spec=Logger) - yield common.TunerContext(ctx, logger) + mock_logger = MagicMock(spec=Logger) + with common.TunerContext(logger=mock_logger) as ctx: + yield ctx def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None: diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 45bcb0d75..8efac1653 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -4,17 +4,16 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import re import logging from dataclasses import astuple, dataclass, field from enum import Enum +from types import TracebackType from typing import Optional from typing import Any from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore -from iree.compiler.dialects import iree_codegen # type: ignore class CommonTypes: @@ -38,10 +37,22 @@ def getI64(self, value: int) -> ir.IntegerAttr: class TunerContext: - def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): - self.mlir_ctx: ir.Context = mlir_ctx - self.logger: logging.Logger = logger - self.type: CommonTypes = CommonTypes(mlir_ctx) + def __init__(self, logger: Optional[logging.Logger] = None): + self.mlir_ctx: ir.Context = ir.Context() + self.logger: logging.Logger = logger or logging.getLogger("tune") + self.type: CommonTypes = CommonTypes(self.mlir_ctx) + + def __enter__(self) -> "TunerContext": + self.mlir_ctx.__enter__() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool: + return self.mlir_ctx.__exit__(exc_type, exc_value, traceback) class DispatchKind(Enum): diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index eba5b35e1..a6c71026d 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -23,9 +23,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]: from logging import Logger from unittest.mock import MagicMock - with ir.Context() as ctx: - logger: Logger = MagicMock(spec=Logger) - yield common.TunerContext(ctx, logger) + mock_logger = MagicMock(spec=Logger) + with common.TunerContext(logger=mock_logger) as ctx: + yield ctx @pytest.fixture diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 1116adac3..9a34e41db 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -25,9 +25,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]: from logging import Logger from unittest.mock import MagicMock - with ir.Context() as ctx: - logger: Logger = MagicMock(spec=Logger) - yield common.TunerContext(ctx, logger) + mock_logger = MagicMock(spec=Logger) + with common.TunerContext(logger=mock_logger) as ctx: + yield ctx def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 7ddb0bb84..204f84b28 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -27,9 +27,9 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]: from logging import Logger from unittest.mock import MagicMock - with ir.Context() as ctx: - logger: Logger = MagicMock(spec=Logger) - yield common.TunerContext(ctx, logger) + mock_logger = MagicMock(spec=Logger) + with common.TunerContext(logger=mock_logger) as ctx: + yield ctx CONTRACTION_TEMPLATE = r""" diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index fab86c369..b18736ffb 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -18,7 +18,6 @@ import math import signal -import subprocess import sys import shutil import logging @@ -26,7 +25,6 @@ from datetime import datetime from enum import Enum from pathlib import Path -import time import multiprocessing import queue from tqdm import tqdm @@ -37,6 +35,7 @@ import iree.runtime as ireert # type: ignore import iree.compiler as ireec # type: ignore from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore from . import candidate_gen from . import dispatch_parser from .op_matchers import * @@ -103,10 +102,8 @@ def get_candidate_vmfb_filename(self, candidate_id: int) -> str: class TuningClient(ABC): - def __init__(self): - mlir_ctx = ir.Context() - logger = logging.getLogger("tune") - self.tuner_context = TunerContext(mlir_ctx, logger) + def __init__(self, tuner_context: TunerContext): + self.tuner_context = tuner_context @abstractmethod def get_iree_compile_flags(self) -> list[str]: @@ -644,15 +641,14 @@ def generate_candidate_specs( # source mlir. mlir_text = candidate_gen.strip_compilation_info(path_config.template_mlir) mlir_module = dispatch_parser.parse_mlir(mlir_text, tuning_client.tuner_context) - with tuning_client.tuner_context.mlir_ctx: - logging.debug("Captured messages from candidate_gen.py:") - config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs( - input_module=mlir_module, - tuner_context=tuning_client.tuner_context, - limit=args.num_candidates, - num_subgroups=args.num_subgroups, - codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline), - ) + logging.debug("Captured messages from candidate_gen.py:") + config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs( + input_module=mlir_module, + tuner_context=tuning_client.tuner_context, + limit=args.num_candidates, + num_subgroups=args.num_subgroups, + codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline), + ) logging.debug("candidate_gen.py ends") handle_error( condition=(len(config_specs) <= 1), msg="Failed to generate any candidates" diff --git a/tuner/tuner/op_matchers.py b/tuner/tuner/op_matchers.py index f3966b97d..09fcd17ea 100644 --- a/tuner/tuner/op_matchers.py +++ b/tuner/tuner/op_matchers.py @@ -125,7 +125,7 @@ def get_map_result_dim_positions(map: ir.AffineMap): class ContractionOpInterfaceMatcher(GenericOpMatcher): - def __init__(self): + def __init__(self) -> None: super().__init__() self.contraction_dimensions: Optional[ContractionDimensions] = None self.lhs_dims: Optional[list[int]] = None