diff --git a/tuner/examples/simple/simple_tuner.py b/tuner/examples/simple/simple_tuner.py index bd5b2eca1..d4ec089fe 100644 --- a/tuner/examples/simple/simple_tuner.py +++ b/tuner/examples/simple/simple_tuner.py @@ -15,10 +15,14 @@ def __init__(self, tuner_context: libtuner.TunerContext): super().__init__(tuner_context) self.compile_flags: list[str] = [] self.benchmark_flags: list[str] = [] + self.compile_timeout: int = 10 def get_iree_compile_flags(self) -> list[str]: return self.compile_flags + def get_iree_compile_timeout_s(self) -> int: + return self.compile_timeout + def get_iree_benchmark_module_flags(self) -> list[str]: return self.benchmark_flags @@ -123,6 +127,7 @@ def main(): print("Compiling models with top candidates...") simple_tuner.compile_flags = compile_flags + simple_tuner.compile_timeout = 60 compiled_model_candidates = libtuner.compile( args, path_config, diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 6184d0c95..63740ee9b 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -110,6 +110,10 @@ def __init__(self, tuner_context: TunerContext): def get_iree_compile_flags(self) -> list[str]: pass + @abstractmethod + def get_iree_compile_timeout_s(self) -> int: + pass + @abstractmethod def get_iree_benchmark_module_flags(self) -> list[str]: pass @@ -122,6 +126,7 @@ def get_benchmark_timeout_s(self) -> int: @dataclass class CompilePack: iree_compile_flags: list[str] + iree_compile_timeout: int candidate_tracker: CandidateTracker @@ -440,30 +445,29 @@ def run_iree_compile_command(compile_pack: CompilePack) -> Optional[int]: logging.debug( f"Compiling candidate {candidate_tracker.candidate_id} with spec: {td_spec_path}" ) - extra_flags = [ - f"--iree-codegen-tuning-spec-path={td_spec_path}", - ] - extra_flags += compile_pack.iree_compile_flags assert candidate_tracker.compiled_vmfb_path, "expected output vmfb path" output_path = candidate_tracker.compiled_vmfb_path.as_posix() crash_dump_path = f"{output_path}.crash_report.mlir" assert candidate_tracker.mlir_path, "expected input mlir file path" input_file = candidate_tracker.mlir_path.as_posix() - # TODO(Max191): Make the device in `traget_backends` a command line option - # instead of hardcoding in ireec.compile_str. - try: - ireec.compile_file( - input_file=input_file, - target_backends=["rocm"], - output_file=output_path, - extra_args=extra_flags, - crash_reproducer_path=crash_dump_path, + iree_compile = ireec.binaries.find_tool("iree-compile") + compile_command = [ + iree_compile, + input_file, + f"-o={output_path}", + f"--mlir-pass-pipeline-crash-reproducer={crash_dump_path}", + f"--iree-codegen-tuning-spec-path={td_spec_path}", + ] + compile_command += compile_pack.iree_compile_flags + result = candidate_gen.run_command( + candidate_gen.RunPack( + command=compile_command, + check=False, + timeout_seconds=compile_pack.iree_compile_timeout, ) - except ireec.CompilerToolError as e: - logging.info(f"Compilation returned non-zero exit status.") - logging.debug(e) + ) + if result.process_res is None or result.is_timeout: return None - return candidate_tracker.candidate_id @@ -775,6 +779,7 @@ def compile( task_list = [ CompilePack( iree_compile_flags=tuning_client.get_iree_compile_flags(), + iree_compile_timeout=tuning_client.get_iree_compile_timeout_s(), candidate_tracker=candidate_trackers[i], ) for i in candidates @@ -783,6 +788,7 @@ def compile( task_list.append( CompilePack( iree_compile_flags=tuning_client.get_iree_compile_flags(), + iree_compile_timeout=tuning_client.get_iree_compile_timeout_s(), candidate_tracker=candidate_trackers[0], ) )