diff --git a/src/lightning_trainable/launcher/__init__.py b/src/lightning_trainable/launcher/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/lightning_trainable/launcher/fit.py b/src/lightning_trainable/launcher/fit.py
new file mode 100644
index 0000000..ca7584e
--- /dev/null
+++ b/src/lightning_trainable/launcher/fit.py
@@ -0,0 +1,68 @@
+from argparse import ArgumentParser
+from importlib import import_module
+from pathlib import Path
+
+from lightning_trainable.launcher.utils import parse_config_dict
+
+import torch
+
+
+def main(args=None):
+    parser = ArgumentParser()
+    parser.add_argument("--pycharm-debug", default=False, type=int,
+                        help="Port of PyCharm remote debugger to connect to.")
+    log_dir_group = parser.add_mutually_exclusive_group()
+    log_dir_group.add_argument("--name", type=str,
+                               help="Name of experiment. Experiment data will be stored in "
+                                    "lightning_logs/`name`/version_X")
+    log_dir_group.add_argument("--log-dir", type=str,
+                               help="Experiment data will be stored in `log_dir`.")
+    parser.add_argument("config_args", type=str, nargs="*",
+                        help="Config files (.yaml/.yml/.json) and key=value pairs making up the model hparams.")
+    args = parser.parse_args(args)
+
+    if args.pycharm_debug:
+        import pydevd_pycharm
+
+        pydevd_pycharm.settrace('localhost', port=args.pycharm_debug, stdoutToServer=True, stderrToServer=True)
+
+    hparams = parse_config_dict(args.config_args)
+
+    # Set number of threads (potentially move into trainable, but it's a global property)
+    num_threads = hparams.pop("num_threads", None)
+    if num_threads is not None:
+        torch.set_num_threads(num_threads)
+
+    # Load the model
+    module_name, model_name = hparams["model"].rsplit(".", 1)
+
+    # Log path
+    if args.name is not None:
+        logger_kwargs = dict(
+            save_dir="lightning_logs",
+            name=args.name.format(
+                model_name=model_name,
+                **hparams
+            )
+        )
+    elif args.log_dir is not None:
+        save_path = Path(args.log_dir)
+        logger_kwargs = dict(
+            version=save_path.name,
+            experiment_name=save_path.parent.name,
+            save_dir=save_path.parent.parent
+        )
+    else:
+        logger_kwargs = dict()
+
+    # "model" is not an hparam; remove it before constructing the model
+    del hparams["model"]
+    module = import_module(module_name)
+    model = getattr(module, model_name)(hparams=hparams)
+
+    # Fit the model
+    return model.fit(logger_kwargs=logger_kwargs)
+
+
+if __name__ == '__main__':
+    main()
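Note on usage: fit.py is invoked as a module (`python -m lightning_trainable.launcher.fit`, which is also how grid.py below launches it), or programmatically through main(). A minimal sketch with a placeholder model path and config file:

    from lightning_trainable.launcher.fit import main

    main([
        "model=my_package.models.MyTrainable",   # import path of the Trainable subclass (placeholder)
        "config.yaml",                           # any .yaml/.yml/.json file with hparams (placeholder)
        "max_epochs=10",                         # key=value overrides, values parsed as YAML
        "--name", "{model_name};{max_epochs}",   # str.format template for the log directory name
    ])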
diff --git a/src/lightning_trainable/launcher/grid.py b/src/lightning_trainable/launcher/grid.py
new file mode 100644
index 0000000..03334a5
--- /dev/null
+++ b/src/lightning_trainable/launcher/grid.py
@@ -0,0 +1,217 @@
+import os
+import platform
+import subprocess
+from dataclasses import dataclass
+from subprocess import Popen
+from collections import namedtuple, Counter
+from concurrent.futures import as_completed
+from concurrent.futures.thread import ThreadPoolExecutor
+from itertools import product
+from math import log10
+from pathlib import Path
+from typing import Dict, List, Tuple
+from datetime import timedelta
+
+from lightning_trainable.launcher.utils import send_telegram_message
+from tqdm import tqdm
+from yaml import dump
+
+ConfigWithCount = namedtuple("ConfigWithCount", ["config", "count"])
+
+
+@dataclass
+class RunResult:
+    config: Dict
+    return_code: int | str
+    stdout: bytes
+    stderr: bytes
+
+
+class GridLauncher:
+    def __init__(self, telegram_info: Dict = None):
+        self.running_processes: List[Popen] = []
+        self.telegram_info = telegram_info
+
+    def send_message(self, message):
+        if self.telegram_info is not None:
+            send_telegram_message(message, **self.telegram_info)
+
+    def run_configuration(self, config: List[Path | str | Tuple[str, object]], num_threads: int = None,
+                          connect_debug: int = None, verbose=False, cli_args=None):
+        """
+        Runs a single configuration using lightning_trainable.launcher.fit
+        in a subprocess and waits for the result.
+        """
+        arguments = []
+        if connect_debug is not None:
+            arguments.append("--pycharm-debug")
+            arguments.append(str(connect_debug))
+        if cli_args is not None:
+            arguments.extend(cli_args)
+
+        if num_threads is not None:
+            config = config + [("num_threads", num_threads)]
+        if len(config) > 0:
+            for value in config:
+                if isinstance(value, (Path, str)):
+                    arguments.append(str(value))
+                elif isinstance(value, tuple):
+                    key, value = value
+                    if isinstance(value, type):
+                        value = f"{value.__module__}.{value.__name__}"
+                    arguments.append(f'{key}={dump(value)}')
+
+        out = None if verbose else subprocess.PIPE
+        with Popen(['python', '-m', 'lightning_trainable.launcher.fit', *arguments],
+                   stdout=out, stderr=out,
+                   # Signals to controller are not passed to runner
+                   preexec_fn=None if platform.system() == "Windows" else os.setpgrp) as process:
+            self.running_processes.append(process)
+            stdout, stderr = process.communicate()
+            self.running_processes.remove(process)
+            return RunResult(config=config, return_code=process.poll(), stdout=stdout, stderr=stderr)
+
+    def grid_spec_to_list(self, config_spec: Dict[str, list] | List[list | Tuple[str, list]]):
+        """
+        Converts a grid of specifications to a list of configurations.
+
+        Each specification can be a list of values to be passed
+        directly to the script or a tuple of a key and a list of values.
+
+        For example:
+        grid_launcher = GridLauncher()
+        grid_launcher.grid_spec_to_list([
+            ("model", ["tests.test_launcher.BasicTrainable"]),
+            ["test_launcher_config.yaml"],
+            ("num_threads", [1, 2, 4]),
+        ])
+        """
+        configs = []
+
+        fake_keys = set()
+
+        # Create fake keys for non-tuple entries
+        dict_args = []
+        for entry in config_spec:
+            if isinstance(entry, tuple):
+                dict_args.append(entry)
+            else:
+                fake_key = f"fake_key_{len(fake_keys)}"
+                fake_keys.add(fake_key)
+                dict_args.append((fake_key, entry))
+        dict_args = dict(dict_args)
+
+        # Create all possible combinations, removing fake keys
+        config_keys = list(dict_args.keys())
+        for config_values in product(*dict_args.values()):
+            config = [
+                value if key in fake_keys else (key, value)
+                for key, value in zip(config_keys, config_values)
+            ]
+            configs.append(config)
+        return configs
+
+    def start_runs(self, configs: List[List[Path | str]], num_parallel_runs=None,
+                   num_threads=1, connect_debug: int = None, verbose=False, cli_args=None):
+        """
+        Starts a number of runs in parallel and returns the pool and the futures.
+        """
+        if num_parallel_runs is None:
+            num_parallel_runs = max(1, os.cpu_count() // num_threads - 1)
+
+        pool = ThreadPoolExecutor(num_parallel_runs)
+        futures = []
+        for config in configs:
+            futures.append(pool.submit(
+                self.run_configuration,
+                config=config, num_threads=num_threads,
+                connect_debug=connect_debug, verbose=verbose,
+                cli_args=cli_args
+            ))
+        return pool, futures
+
+    def run_configs_and_wait(self,
+                             configs: List[List[Path | str]], num_parallel_runs=None,
+                             num_threads=1, connect_debug: int = None, verbose=False,
+                             cli_args=None) -> List[RunResult]:
+        """
+        Runs a list of configurations in parallel and waits for the results.
+        """
+        pool, futures = self.start_runs(
+            configs,
+            num_parallel_runs=num_parallel_runs, num_threads=num_threads,
+            connect_debug=connect_debug, verbose=verbose, cli_args=cli_args
+        )
+        interrupted_count = 0
+        while True:
+            try:
+                results = self.fetch_results(futures)
+                break
+            except KeyboardInterrupt:
+                interrupted_count += 1
+                if interrupted_count == 1:
+                    # Cancel future runs
+                    pool.shutdown(wait=False, cancel_futures=True)
+                    # Pool shutdown does not mark futures as_completed
+                    # https://github.com/python/cpython/issues/87893
+                    for f in tqdm(futures, desc="Cancelling future runs"):
+                        if f.cancelled():
+                            f.set_running_or_notify_cancel()
+                    print("Stopped all pending experiments.")
+                    print("Hit Ctrl-C again to cancel running experiments.")
+                elif interrupted_count == 2:
+                    # Cancel current runs
+                    for process in tqdm(self.running_processes, desc="Killing processes"):
+                        process.kill()
+                    print("Stopped all running experiments.")
+                if interrupted_count > 2:
+                    raise KeyboardInterrupt
+                # Wait for remaining processes
+                pool.shutdown(wait=True)
+
+        # Print results
+        status_counts = status_count_counter(results)
+        print(f"Done running {len(configs)} experiments: {status_counts}")
+        if len(set(status_counts) - {0}) > 0:
+            print(f"Total: {sum(value for key, value in status_counts.items() if key != 0)} FAILED!")
+        else:
+            print("All succeeded :D")
+        self.send_message(f"Launcher done: {status_counts}")
+        return results
+
+    def fetch_results(self, futures, timeout=None):
+        """
+        Fetches the results from a list of futures.
+        """
+        last_elapsed = 60
+        results = []
+        with tqdm(as_completed(futures, timeout=timeout), total=len(futures), smoothing=0) as pbar:
+            for future in pbar:
+                if future.cancelled():
+                    result = RunResult(None, "cancelled", None, None)
+                else:
+                    result: RunResult = future.result()
+                results.append(result)
+
+                status_counts = status_count_counter(results)
+                result_code = result.return_code
+
+                elapsed = pbar.format_dict["elapsed"]
+                elapsed_delta = timedelta(seconds=elapsed)
+                if result_code != 0 and status_counts[result_code] == 10 ** int(
+                        log10(status_counts[result_code])):
+                    self.send_message(
+                        f"Code {result_code}: {status_counts[result_code]} "
+                        f"failed after {elapsed_delta}."
+                    )
+                elif result_code == 0 and elapsed > last_elapsed * 2:
+                    self.send_message(
+                        f"{status_counts[result_code]} succeeded after {elapsed_delta}."
+                    )
+                    last_elapsed = elapsed
+                pbar.set_description(str(status_counts))
+        return results
+
+
+def status_count_counter(results: List[RunResult]) -> Counter:
+    return Counter(result.return_code for result in results)
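A minimal sketch of driving the GridLauncher from Python; the model path, config file and Telegram credentials are placeholders (the tests at the end of this patch are a working end-to-end example):

    from lightning_trainable.launcher.grid import GridLauncher, status_count_counter

    launcher = GridLauncher(telegram_info={"token": "<bot token>", "chats": [12345678]})
    configs = launcher.grid_spec_to_list([
        ("model", ["my_package.models.MyTrainable"]),
        ["config.yaml"],                  # non-tuple entries are passed through unchanged
        ("max_epochs", [10, 100]),        # grid axis with two values
    ])
    results = launcher.run_configs_and_wait(configs, num_threads=2)
    print(status_count_counter(results))  # Counter of subprocess return codes, e.g. Counter({0: 2})

Each configuration runs `python -m lightning_trainable.launcher.fit` in its own subprocess; a first Ctrl-C cancels pending runs, a second one kills the running ones.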
diff --git a/src/lightning_trainable/launcher/utils.py b/src/lightning_trainable/launcher/utils.py
new file mode 100644
index 0000000..ce0f44d
--- /dev/null
+++ b/src/lightning_trainable/launcher/utils.py
@@ -0,0 +1,46 @@
+from pathlib import Path
+from typing import List
+from urllib.parse import urlencode
+from urllib.request import Request, urlopen
+
+from yaml import safe_load
+
+
+def parse_config_dict(config_spec: List[Path | str]):
+    hparams = {}
+    for arg in config_spec:
+        if isinstance(arg, Path) or (
+            any(
+                arg.endswith(suffix) for suffix in [".yaml", ".yml", ".json"]
+            ) and "=" not in arg
+        ):
+            # Read multiple entries from .yaml file
+            with open(arg, "r") as file:
+                new_hparams = safe_load(file)
+        else:
+            # Read single entry from command line
+            key, value = arg.split("=")
+            new_hparams = {key: safe_load(value)}
+
+        # Merge in new parameters
+        for key, value in new_hparams.items():
+            hparam_level = hparams
+            key_path = key.split(".")
+            for key_entry in key_path[:-1]:
+                hparam_level = hparam_level[key_entry]
+            hparam_level[key_path[-1]] = value
+    return hparams
+
+
+def send_telegram_message(message: str, token: str, chats: List[int]):
+    try:
+        url = f"https://api.telegram.org/bot{token}/sendMessage"
+        for chat_id in chats:
+            params = {
+                "chat_id": chat_id,
+                "text": message
+            }
+            request = Request(url, urlencode(params).encode())
+            urlopen(request).read().decode()
+    except OSError:  # notifications are best-effort; never crash a run over a failed message
+        pass
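A worked example of parse_config_dict (the file name is hypothetical): assuming base.yaml contains `optimizer: {name: adam}`,

    parse_config_dict(["base.yaml", "optimizer.name=sgd", "max_epochs=10"])
    # -> {"optimizer": {"name": "sgd"}, "max_epochs": 10}

Values are parsed with yaml.safe_load, and dotted keys descend into nested dicts that must already exist (here created by the YAML file).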
diff --git a/tests/test_launcher.py b/tests/test_launcher.py
new file mode 100644
index 0000000..4a9b52d
--- /dev/null
+++ b/tests/test_launcher.py
@@ -0,0 +1,60 @@
+import os
+from pathlib import Path
+
+import torch
+from lightning_trainable import Trainable, TrainableHParams
+from lightning_trainable.launcher.fit import main
+from torch.utils.data import TensorDataset
+
+from lightning_trainable.launcher.grid import GridLauncher, status_count_counter
+
+
+class BasicTrainableHParams(TrainableHParams):
+    domain: list
+
+
+class BasicTrainable(Trainable):
+    def __init__(self, hparams: BasicTrainableHParams | dict):
+        if not isinstance(hparams, BasicTrainableHParams):
+            hparams = BasicTrainableHParams(**hparams)
+
+        x = torch.linspace(*hparams.domain, 1000)[:, None]
+        y = torch.sin(x)
+        train_data = TensorDataset(x, y)
+
+        super().__init__(hparams, train_data=train_data)
+
+        self.model = torch.nn.Sequential(
+            torch.nn.Linear(1, 16), torch.nn.ReLU(),
+            torch.nn.Linear(16, 1)
+        )
+
+    def compute_metrics(self, batch, batch_idx) -> dict:
+        x, y = batch
+        return {
+            "loss": ((self.model(x) - y) ** 2).mean()
+        }
+
+
+def test_fit_launcher():
+    main([
+        "model=tests.test_launcher.BasicTrainable",
+        str(Path(__file__).parent / "test_launcher_config.yaml"),
+        "max_epochs=1",
+        "domain=[-5, 5]",
+        "accelerator='cpu'",
+        "--name", "{model_name};{max_epochs}"
+    ])
+
+
+def test_grid_launcher():
+    launcher = GridLauncher()
+    config_list = launcher.grid_spec_to_list([
+        ("model", ["tests.test_launcher.BasicTrainable"]),
+        ([Path(__file__).parent / "test_launcher_config.yaml"]),
+        ("max_epochs", [1]),
+        ("domain", [[-5, 5], [-3, 3]]),
+        ("accelerator", ['cpu'])
+    ])
+    results = launcher.run_configs_and_wait(config_list, cli_args=["--name", "{model_name};{max_epochs}"])
+    assert status_count_counter(results) == {0: 2}
diff --git a/tests/test_launcher_config.yaml b/tests/test_launcher_config.yaml
new file mode 100644
index 0000000..1750891
--- /dev/null
+++ b/tests/test_launcher_config.yaml
@@ -0,0 +1 @@
+batch_size: 128
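For reference, the grid spec in test_grid_launcher above expands to two configurations (one per domain value), roughly:

    [("model", "tests.test_launcher.BasicTrainable"), Path(".../test_launcher_config.yaml"),
     ("max_epochs", 1), ("domain", [-5, 5]), ("accelerator", "cpu")]
    [("model", "tests.test_launcher.BasicTrainable"), Path(".../test_launcher_config.yaml"),
     ("max_epochs", 1), ("domain", [-3, 3]), ("accelerator", "cpu")]

so the asserted status counter {0: 2} means both subprocesses exited with return code 0.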