Skip to content

Commit

Permalink
remove code comments and add type hints
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Jan 7, 2025
1 parent d85655f commit 5fbde03
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 25 deletions.
3 changes: 1 addition & 2 deletions tuner/examples/test/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import logging
from pathlib import Path
from tuner import libtuner
from iree.compiler import ir # type: ignore
from tuner.common import *


Expand Down Expand Up @@ -93,6 +91,7 @@ def main():

path_config = libtuner.PathConfig()
path_config.base_dir.mkdir(parents=True, exist_ok=True)
path_config.output_unilog.touch()
# TODO(Max191): Make candidate_trackers internal to TuningClient.
candidate_trackers: list[libtuner.CandidateTracker] = []
stop_after_phase: str = args.stop_after
Expand Down
3 changes: 0 additions & 3 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

# Mock the logger
mock_logger = MagicMock(spec=Logger)

# Use TunerContext with the mocked logger
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx

Expand Down
14 changes: 9 additions & 5 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
import logging
from dataclasses import astuple, dataclass, field
from enum import Enum
from types import TracebackType
from typing import Optional
from typing import Any
from typing_extensions import Literal
Expand Down Expand Up @@ -46,13 +46,17 @@ def __init__(self, logger: Optional[logging.Logger] = None):
) # Default to "tune" logger
self.type: CommonTypes = CommonTypes(self.mlir_ctx)

def __enter__(self) -> TunerContext:
def __enter__(self) -> "TunerContext":
self.mlir_ctx.__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]:
self.mlir_ctx.__exit__(exc_type, exc_value, traceback)
return False
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool:
return self.mlir_ctx.__exit__(exc_type, exc_value, traceback)


class DispatchKind(Enum):
Expand Down
3 changes: 0 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

# Mock the logger
mock_logger = MagicMock(spec=Logger)

# Use TunerContext with the mocked logger
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx

Expand Down
3 changes: 0 additions & 3 deletions tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
from logging import Logger
from unittest.mock import MagicMock

# Mock the logger
mock_logger = MagicMock(spec=Logger)

# Use TunerContext with the mocked logger
with common.TunerContext(logger=mock_logger) as ctx:
yield ctx

Expand Down
17 changes: 8 additions & 9 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,15 +642,14 @@ def generate_candidate_specs(
# source mlir.
mlir_text = candidate_gen.strip_compilation_info(path_config.template_mlir)
mlir_module = dispatch_parser.parse_mlir(mlir_text, tuning_client.tuner_context)
with tuning_client.tuner_context as tuner_context:
logging.debug("Captured messages from candidate_gen.py:")
config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs(
input_module=mlir_module,
tuner_context=tuner_context,
limit=args.num_candidates,
num_subgroups=args.num_subgroups,
codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline),
)
logging.debug("Captured messages from candidate_gen.py:")
config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs(
input_module=mlir_module,
tuner_context=tuning_client.tuner_context,
limit=args.num_candidates,
num_subgroups=args.num_subgroups,
codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline),
)
logging.debug("candidate_gen.py ends")
handle_error(
condition=(len(config_specs) <= 1), msg="Failed to generate any candidates"
Expand Down

0 comments on commit 5fbde03

Please sign in to comment.