From 5fbde031339d84d672bd375401654b2e29235598 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Tue, 7 Jan 2025 14:58:23 -0600 Subject: [PATCH] remove code comments and add type hints Signed-off-by: Bangtian Liu --- tuner/examples/test/tuner_test.py | 3 +-- tuner/tuner/candidate_gen_test.py | 3 --- tuner/tuner/common.py | 14 +++++++++----- tuner/tuner/common_test.py | 3 --- tuner/tuner/dispatch_constraints_test.py | 3 --- tuner/tuner/libtuner.py | 17 ++++++++--------- 6 files changed, 18 insertions(+), 25 deletions(-) diff --git a/tuner/examples/test/tuner_test.py b/tuner/examples/test/tuner_test.py index c4e80b119..7d4fa58bf 100644 --- a/tuner/examples/test/tuner_test.py +++ b/tuner/examples/test/tuner_test.py @@ -5,10 +5,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import argparse -import logging from pathlib import Path from tuner import libtuner -from iree.compiler import ir # type: ignore from tuner.common import * @@ -93,6 +91,7 @@ def main(): path_config = libtuner.PathConfig() path_config.base_dir.mkdir(parents=True, exist_ok=True) + path_config.output_unilog.touch() # TODO(Max191): Make candidate_trackers internal to TuningClient. candidate_trackers: list[libtuner.CandidateTracker] = [] stop_after_phase: str = args.stop_after diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 990ffc22c..6a62e90e4 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -27,10 +27,7 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]: from logging import Logger from unittest.mock import MagicMock - # Mock the logger mock_logger = MagicMock(spec=Logger) - - # Use TunerContext with the mocked logger with common.TunerContext(logger=mock_logger) as ctx: yield ctx diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index cb65bcad2..b69ffe0e5 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -4,10 +4,10 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from __future__ import annotations import logging from dataclasses import astuple, dataclass, field from enum import Enum +from types import TracebackType from typing import Optional from typing import Any from typing_extensions import Literal @@ -46,13 +46,17 @@ def __init__(self, logger: Optional[logging.Logger] = None): ) # Default to "tune" logger self.type: CommonTypes = CommonTypes(self.mlir_ctx) - def __enter__(self) -> TunerContext: + def __enter__(self) -> "TunerContext": self.mlir_ctx.__enter__() return self - def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]: - self.mlir_ctx.__exit__(exc_type, exc_value, traceback) - return False + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool: + return self.mlir_ctx.__exit__(exc_type, exc_value, traceback) class DispatchKind(Enum): diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index b46a96919..a6c71026d 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -23,10 +23,7 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]: from logging import Logger from unittest.mock import MagicMock - # Mock the logger mock_logger = MagicMock(spec=Logger) - - # Use TunerContext with the mocked logger with common.TunerContext(logger=mock_logger) as ctx: yield ctx diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 8b8bd1002..fafb37102 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -25,10 +25,7 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]: from logging import Logger from unittest.mock import MagicMock - # Mock the logger mock_logger = MagicMock(spec=Logger) - - # Use TunerContext with the mocked logger with common.TunerContext(logger=mock_logger) as ctx: yield ctx diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 006db0bfb..9bcfb5077 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -642,15 +642,14 @@ def generate_candidate_specs( # source mlir. mlir_text = candidate_gen.strip_compilation_info(path_config.template_mlir) mlir_module = dispatch_parser.parse_mlir(mlir_text, tuning_client.tuner_context) - with tuning_client.tuner_context as tuner_context: - logging.debug("Captured messages from candidate_gen.py:") - config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs( - input_module=mlir_module, - tuner_context=tuner_context, - limit=args.num_candidates, - num_subgroups=args.num_subgroups, - codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline), - ) + logging.debug("Captured messages from candidate_gen.py:") + config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs( + input_module=mlir_module, + tuner_context=tuning_client.tuner_context, + limit=args.num_candidates, + num_subgroups=args.num_subgroups, + codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline), + ) logging.debug("candidate_gen.py ends") handle_error( condition=(len(config_specs) <= 1), msg="Failed to generate any candidates"