From 2b5e1c38e4f96739ed3abf182d8baa8e2ab5bd26 Mon Sep 17 00:00:00 2001
From: Felix Draxler
Date: Fri, 24 Mar 2023 09:46:09 +0100
Subject: [PATCH 1/9] Single model fit script reading in lots of configs

---
 src/lightning_trainable/launcher/__init__.py |  0
 src/lightning_trainable/launcher/fit.py      | 76 ++++++++++++++++++++
 2 files changed, 76 insertions(+)
 create mode 100644 src/lightning_trainable/launcher/__init__.py
 create mode 100644 src/lightning_trainable/launcher/fit.py

diff --git a/src/lightning_trainable/launcher/__init__.py b/src/lightning_trainable/launcher/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/lightning_trainable/launcher/fit.py b/src/lightning_trainable/launcher/fit.py
new file mode 100644
index 0000000..c609661
--- /dev/null
+++ b/src/lightning_trainable/launcher/fit.py
@@ -0,0 +1,76 @@
+from argparse import ArgumentParser
+from importlib import import_module
+from pathlib import Path
+
+from yaml import full_load, safe_load
+
+import torch
+
+if __name__ == '__main__':
+    parser = ArgumentParser()
+    parser.add_argument("--pycharm-debug", default=False, type=int,
+                        help="Port of PyCharm remote debugger to connect to.")
+    log_dir_group = parser.add_mutually_exclusive_group()
+    log_dir_group.add_argument("--name", type=str,
+                               help="Name of experiment. Experiment data will be stored in "
+                                    "lightning_logs/`name`/version_X")
+    log_dir_group.add_argument("--log-dir", type=str,
+                               help="Experiment data will be stored `log_dir`'")
+    parser.add_argument("config_args", type=str, nargs="*",
+                        help="")
+    args = parser.parse_args()
+
+    if args.pycharm_debug:
+        import pydevd_pycharm
+
+        pydevd_pycharm.settrace('localhost', port=args.pycharm_debug, stdoutToServer=True, stderrToServer=True)
+
+    hparams = {}
+    for arg in args.config_args:
+        if "=" not in arg and any(
+                arg.endswith(suffix) for suffix in [".yaml", ".yml", ".json"]
+        ):
+            # Read multiple entries from .yaml file
+            with open(arg, "r") as file:
+                new_hparams = full_load(file)
+        else:
+            # Read single entry from command line
+            key, value = arg.split("=")
+            new_hparams = {key: safe_load(value)}
+
+        # Merge in new parameters
+        for key, value in new_hparams.items():
+            hparam_level = hparams
+            key_path = key.split(".")
+            for key_entry in key_path[:-1]:
+                hparam_level = hparam_level[key_entry]
+            hparam_level[key_path[-1]] = value
+
+    # Set number of threads (potentially move into trainable, but it's a global property)
+    num_threads = hparams.get("num_threads", None)
+    if num_threads is not None:
+        torch.set_num_threads(num_threads)
+
+    # Load the model
+    module_name, model_name = hparams.pop("model")
+    module = import_module(module_name)
+    model = getattr(module, model_name)(hparams=hparams)
+
+    # Log path
+    if args.name is not None:
+        logger_kwargs = dict(
+            save_dir="lightning_logs",
+            name=args.name
+        )
+    elif args.save_dir is not None:
+        save_path = Path(args.save_dir)
+        logger_kwargs = dict(
+            version=save_path.name,
+            experiment_name=save_path.parent.name,
+            save_dir=save_path.parent.parent
+        )
+    else:
+        logger_kwargs = dict()
+
+    # Fit the model
+    model.fit(logger_kwargs=logger_kwargs)

From 18b58a7b200340bf7e0666e9e675d0b84786702b Mon Sep 17 00:00:00 2001
From: Felix Draxler
Date: Fri, 24 Mar 2023 10:20:32 +0100
Subject: [PATCH 2/9] Add grid fit launcher

---
 src/lightning_trainable/launcher/fit.py   |  23 +--
 src/lightning_trainable/launcher/grid.py  | 178 ++++++++++++++++++++++
 src/lightning_trainable/launcher/utils.py |  45 ++++++
 3 files changed, 225 insertions(+), 21 deletions(-)
 create mode 100644 src/lightning_trainable/launcher/grid.py
 create mode 100644 src/lightning_trainable/launcher/utils.py

diff --git a/src/lightning_trainable/launcher/fit.py b/src/lightning_trainable/launcher/fit.py
index c609661..e3c6c9b 100644
--- a/src/lightning_trainable/launcher/fit.py
+++ b/src/lightning_trainable/launcher/fit.py
@@ -2,7 +2,7 @@
 from importlib import import_module
 from pathlib import Path
 
-from yaml import full_load, safe_load
+from lightning_trainable.launcher.utils import parse_config_dict
 
 import torch
 
@@ -25,26 +25,7 @@
 
     pydevd_pycharm.settrace('localhost', port=args.pycharm_debug, stdoutToServer=True, stderrToServer=True)
 
-    hparams = {}
-    for arg in args.config_args:
-        if "=" not in arg and any(
-                arg.endswith(suffix) for suffix in [".yaml", ".yml", ".json"]
-        ):
-            # Read multiple entries from .yaml file
-            with open(arg, "r") as file:
-                new_hparams = full_load(file)
-        else:
-            # Read single entry from command line
-            key, value = arg.split("=")
-            new_hparams = {key: safe_load(value)}
-
-        # Merge in new parameters
-        for key, value in new_hparams.items():
-            hparam_level = hparams
-            key_path = key.split(".")
-            for key_entry in key_path[:-1]:
-                hparam_level = hparam_level[key_entry]
-            hparam_level[key_path[-1]] = value
+    hparams = parse_config_dict(args.config_args)
 
     # Set number of threads (potentially move into trainable, but it's a global property)
     num_threads = hparams.get("num_threads", None)
diff --git a/src/lightning_trainable/launcher/grid.py b/src/lightning_trainable/launcher/grid.py
new file mode 100644
index 0000000..4b409e8
--- /dev/null
+++ b/src/lightning_trainable/launcher/grid.py
@@ -0,0 +1,178 @@
+import os
+import subprocess
+from dataclasses import dataclass
+from subprocess import Popen
+from collections import namedtuple, Counter
+from concurrent.futures import as_completed
+from concurrent.futures.thread import ThreadPoolExecutor
+from itertools import product
+from math import log10
+from pathlib import Path
+from typing import Dict, List
+from datetime import timedelta
+
+from lightning_trainable.launcher.utils import send_telegram_message
+from tqdm import tqdm
+from yaml import dump
+
+ConfigWithCount = namedtuple("ConfigWithCount", ["config", "count"])
+
+
+@dataclass
+class RunResult:
+    config: Dict
+    return_code: int | str
+    stdout: bytes
+    stderr: bytes
+
+
+class ExperimentLauncher:
+    def __init__(self, telegram_info: Dict = None):
+        self.running_processes: List[Popen] = []
+        self.telegram_info = telegram_info
+
+    def send_message(self, message):
+        if self.telegram_info is not None:
+            send_telegram_message(message, **self.telegram_info)
+
+    def run_configuration(self, config, num_threads: int = None, connect_debug: int = None, verbose=False):
+        """
+        Runs a single configuration using lightning_trainable.launcher.fit
+        in a subprocess and waits for the result.
+        """
+        arguments = []
+        if connect_debug is not None:
+            arguments.append("--pycharm-debug")
+            arguments.append(str(connect_debug))
+
+        all_config = {
+            "num_threads": num_threads,
+            **config
+        }
+        if len(all_config) > 0:
+            for key, value in all_config.items():
+                arguments.append(f'{key}={dump(value)}')
+
+        out = None if verbose else subprocess.PIPE
+        with Popen(['python', '-m', 'lightning_trainable.launcher.fit', *arguments],
+                   stdout=out, stderr=out,
+                   # Signals to controller are not passed to runner
+                   preexec_fn=os.setpgrp) as process:
+            self.running_processes.append(process)
+            stdout, stderr = process.communicate()
+            self.running_processes.remove(process)
+        return RunResult(config=config, return_code=process.poll(), stdout=stdout, stderr=stderr)
+
+    def grid_spec_to_list(self, config_spec: Dict[str, list]):
+        """
+        Converts a grid of name=list[... values] to a list of dict configurations.
+        """
+        config_keys = list(config_spec.keys())
+        configs = []
+        for config_values in product(*config_spec.values()):
+            config = dict(zip(config_keys, config_values))
+            configs.append(config)
+        return configs
+
+    def start_runs(self, configs: List[List[Path, str]], num_parallel_runs=None,
+                   num_threads=None, connect_debug: int = None, verbose=False):
+        """
+        Starts a number of runs in parallel and returns the futures.
+        """
+        if num_parallel_runs is None:
+            num_parallel_runs = max(1, os.cpu_count() // num_threads - 1)
+
+        pool = ThreadPoolExecutor(num_parallel_runs)
+        futures = []
+        for config in configs:
+            futures.append(pool.submit(
+                self.run_configuration,
+                config=config, num_threads=num_threads,
+                connect_debug=connect_debug, verbose=verbose
+            ))
+        return pool, futures
+
+    def run_configs_and_wait(self,
+                             configs: List[List[Path, str]], num_parallel_runs=None,
+                             num_threads=None, connect_debug: int = None, verbose=False) -> List[RunResult]:
+        """
+        Runs a list of configurations in parallel and waits for the results.
+        """
+        pool, futures = self.start_runs(
+            configs,
+            num_parallel_runs=num_parallel_runs, num_threads=num_threads,
+            connect_debug=connect_debug, verbose=verbose
+        )
+        interrupted_count = 0
+        while True:
+            try:
+                results = self.fetch_results(futures)
+                break
+            except KeyboardInterrupt:
+                interrupted_count += 1
+                if interrupted_count == 1:
+                    # Cancel future runs
+                    pool.shutdown(wait=False, cancel_futures=True)
+                    # Pool shutdown does not mark futures as_completed
+                    # https://github.com/python/cpython/issues/87893
+                    for f in tqdm(futures, desc="Cancelling future runs"):
+                        if f.cancelled():
+                            f.set_running_or_notify_cancel()
+                    print("Stopped all pending experiments.")
+                    print("Hit Ctrl-C again to cancel running experiments.")
+                elif interrupted_count == 2:
+                    # Cancel current runs
+                    for process in tqdm(self.running_processes, desc="Killing processes"):
+                        process.kill()
+                    print("Stopped all running experiments.")
+                if interrupted_count > 2:
+                    raise KeyboardInterrupt
+        # Wait for remaining processes
+        pool.shutdown(wait=True)
+
+        # Print results
+        status_counts = status_count_counter(results)
+        print(f"Done running {sum([config.count for config in configs])} experiments: {status_counts}")
+        if len(set(status_counts) - {0}) > 0:
+            print(f"Total: {sum(value for key, value in status_counts.items() if key != 0)} FAILED!")
+        else:
+            print("All succeeded :D")
+        self.send_message(f"Launcher done: {status_counts}")
+        return results
+
+    def fetch_results(self, futures, timeout=None):
+        """
+        Fetches the results from a list of futures.
+        """
+        last_elapsed = 60
+        results = []
+        with tqdm(as_completed(futures, timeout=timeout), total=len(futures), smoothing=0) as pbar:
+            for future in pbar:
+                if future.cancelled():
+                    result = RunResult(None, "cancelled", None, None)
+                else:
+                    result: RunResult = future.result()
+                results.append(result)
+
+                status_counts = status_count_counter(results)
+                result_code = result.return_code
+
+                elapsed = pbar.format_dict["elapsed"]
+                elapsed_delta = timedelta(seconds=elapsed)
+                if result_code != 0 and status_counts[result_code] == 10 ** int(
+                        log10(status_counts[result_code])):
+                    self.send_message(
+                        f"Code {result_code}: {status_counts[result_code]} "
+                        f"failed after {elapsed_delta}."
+                    )
+                elif result_code == 0 and elapsed > last_elapsed * 2:
+                    self.send_message(
+                        f"{status_counts[result_code]} succeeded after {elapsed_delta}."
+                    )
+                    last_elapsed = elapsed
+                pbar.set_description(str(status_counts))
+        return results
+
+
+def status_count_counter(results: List[RunResult]) -> Counter:
+    return Counter(result.return_code for result in results)
diff --git a/src/lightning_trainable/launcher/utils.py b/src/lightning_trainable/launcher/utils.py
new file mode 100644
index 0000000..03bd158
--- /dev/null
+++ b/src/lightning_trainable/launcher/utils.py
@@ -0,0 +1,45 @@
+from pathlib import Path
+from typing import List
+from urllib.parse import urlencode
+from urllib.request import Request, urlopen
+
+from yaml import safe_load
+
+
+def parse_config_dict(config_spec: List[Path, str]):
+    hparams = {}
+    for arg in config_spec:
+        if isinstance(arg, Path) or (
+                any(
+                    arg.endswith(suffix) for suffix in [".yaml", ".yml", ".json"]
+                ) and "=" not in arg
+        ):
+            # Read multiple entries from .yaml file
+            with open(arg, "r") as file:
+                new_hparams = safe_load(file)
+        else:
+            # Read single entry from command line
+            key, value = arg.split("=")
+            new_hparams = {key: safe_load(value)}
+
+        # Merge in new parameters
+        for key, value in new_hparams.items():
+            hparam_level = hparams
+            key_path = key.split(".")
+            for key_entry in key_path[:-1]:
+                hparam_level = hparam_level[key_entry]
+            hparam_level[key_path[-1]] = value
+
+
+def send_telegram_message(message: str, token: str, chats: List[int]):
+    try:
+        url = f"https://api.telegram.org/bot{token}/sendMessage"
+        for chat_id in chats:
+            params = {
+                "chat_id": chat_id,
+                "text": message
+            }
+            request = Request(url, urlencode(params).encode())
+            urlopen(request).read().decode()
+    except FileNotFoundError:
+        pass

From 991c45c51021dbff16465079b077dfe108ac1e3b Mon Sep 17 00:00:00 2001
From: Felix Draxler
Date: Fri, 24 Mar 2023 10:41:16 +0100
Subject: [PATCH 3/9] Test fit launcher

---
 src/lightning_trainable/launcher/fit.py   | 15 +++++---
 src/lightning_trainable/launcher/utils.py |  3 +-
 tests/test_launcher.py                    | 42 +++++++++++++++++++
 3 files changed, 54 insertions(+), 6 deletions(-)
 create mode 100644 tests/test_launcher.py

diff --git a/src/lightning_trainable/launcher/fit.py b/src/lightning_trainable/launcher/fit.py
index e3c6c9b..3e0e8af 100644
--- a/src/lightning_trainable/launcher/fit.py
+++ b/src/lightning_trainable/launcher/fit.py
@@ -6,7 +6,8 @@
 
 import torch
 
-if __name__ == '__main__':
+
+def main(args=None):
     parser = ArgumentParser()
     parser.add_argument("--pycharm-debug", default=False, type=int,
                         help="Port of PyCharm remote debugger to connect to.")
@@ -18,7 +19,7 @@
                                help="Experiment data will be stored `log_dir`'")
     parser.add_argument("config_args", type=str, nargs="*",
                         help="")
-    args = parser.parse_args()
+    args = parser.parse_args(args)
 
     if args.pycharm_debug:
         import pydevd_pycharm
@@ -33,7 +34,7 @@
         torch.set_num_threads(num_threads)
 
     # Load the model
-    module_name, model_name = hparams.pop("model")
+    module_name, model_name = hparams.pop("model").rsplit(".", 1)
     module = import_module(module_name)
     model = getattr(module, model_name)(hparams=hparams)
 
@@ -43,7 +44,7 @@
             save_dir="lightning_logs",
             name=args.name
         )
-    elif args.save_dir is not None:
+    elif args.log_dir is not None:
         save_path = Path(args.save_dir)
         logger_kwargs = dict(
             version=save_path.name,
@@ -54,4 +55,8 @@
         logger_kwargs = dict()
 
     # Fit the model
-    model.fit(logger_kwargs=logger_kwargs)
+    return model.fit(logger_kwargs=logger_kwargs)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/src/lightning_trainable/launcher/utils.py b/src/lightning_trainable/launcher/utils.py
index 03bd158..ce0f44d 100644
--- a/src/lightning_trainable/launcher/utils.py
+++ b/src/lightning_trainable/launcher/utils.py
@@ -6,7 +6,7 @@
 from yaml import safe_load
 
 
-def parse_config_dict(config_spec: List[Path, str]):
+def parse_config_dict(config_spec: List[Path | str]):
     hparams = {}
     for arg in config_spec:
         if isinstance(arg, Path) or (
@@ -29,6 +29,7 @@ def parse_config_dict(config_spec: List[Path, str]):
             for key_entry in key_path[:-1]:
                 hparam_level = hparam_level[key_entry]
             hparam_level[key_path[-1]] = value
+    return hparams
 
 
 def send_telegram_message(message: str, token: str, chats: List[int]):
diff --git a/tests/test_launcher.py b/tests/test_launcher.py
new file mode 100644
index 0000000..071c7a3
--- /dev/null
+++ b/tests/test_launcher.py
@@ -0,0 +1,42 @@
+import torch
+from lightning_trainable import Trainable, TrainableHParams
+from lightning_trainable.launcher.fit import main
+from torch.utils.data import TensorDataset
+
+
+class BasicTrainableHParams(TrainableHParams):
+    data_set_name: str
+
+
+class BasicTrainable(Trainable):
+    def __init__(self, hparams: BasicTrainableHParams | dict):
+        if not isinstance(hparams, BasicTrainableHParams):
+            hparams = BasicTrainableHParams(**hparams)
+
+        assert hparams.data_set_name == "sine"
+        x = torch.linspace(-5, 5, 1000)[:, None]
+        y = torch.sin(x)[:, None]
+        train_data = TensorDataset(x, y)
+
+        super().__init__(hparams, train_data=train_data)
+
+        self.model = torch.nn.Sequential(
+            torch.nn.Linear(1, 16), torch.nn.ReLU(),
+            torch.nn.Linear(16, 1)
+        )
+
+    def compute_metrics(self, batch, batch_idx) -> dict:
+        x, y = batch
+        return {
+            "loss": ((self.model(x) - y) ** 2).mean()
+        }
+
+
+def test_fit_launcher():
+    main([
+        "model=tests.test_launcher.BasicTrainable",
+        "batch_size=128",
+        "max_epochs=1",
+        "data_set_name='sine'",
+        "accelerator='cpu'"
+    ])

From cb0c299fa9701d6aa725fdcbc61b9300dcfa4e7b Mon Sep 17 00:00:00 2001
From: Felix Draxler
Date: Fri, 24 Mar 2023 11:44:08 +0100
Subject: [PATCH 4/9] Grid launcher

---
 src/lightning_trainable/launcher/fit.py  |  2 +-
 src/lightning_trainable/launcher/grid.py | 74 +++++++++++++++++------
 tests/test_launcher.py                   | 27 +++++++--
 tests/test_launcher_config.yaml          |  1 +
 4 files changed, 78 insertions(+), 26 deletions(-)
 create mode 100644 tests/test_launcher_config.yaml

diff --git a/src/lightning_trainable/launcher/fit.py b/src/lightning_trainable/launcher/fit.py
index 3e0e8af..4f7b840 100644
--- a/src/lightning_trainable/launcher/fit.py
+++ b/src/lightning_trainable/launcher/fit.py
@@ -29,7 +29,7 @@ def main(args=None):
     hparams = parse_config_dict(args.config_args)
 
     # Set number of threads (potentially move into trainable, but it's a global property)
-    num_threads = hparams.get("num_threads", None)
+    num_threads = hparams.pop("num_threads", None)
     if num_threads is not None:
         torch.set_num_threads(num_threads)
 
diff --git a/src/lightning_trainable/launcher/grid.py b/src/lightning_trainable/launcher/grid.py
index 4b409e8..81ebf0e 100644
--- a/src/lightning_trainable/launcher/grid.py
+++ b/src/lightning_trainable/launcher/grid.py
@@ -8,7 +8,7 @@
 from itertools import product
 from math import log10
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Tuple
 from datetime import timedelta
 
 from lightning_trainable.launcher.utils import send_telegram_message
@@ -26,7 +26,7 @@ class RunResult:
     stderr: bytes
 
 
-class ExperimentLauncher:
+class GridLauncher:
     def __init__(self, telegram_info: Dict = None):
         self.running_processes: List[Popen] = []
         self.telegram_info = telegram_info
@@ -35,7 +35,8 @@ def send_message(self, message):
         if self.telegram_info is not None:
             send_telegram_message(message, **self.telegram_info)
 
-    def run_configuration(self, config, num_threads: int = None, connect_debug: int = None, verbose=False):
+    def run_configuration(self, config: List[Path | str | Tuple[str, object]], num_threads: int = None,
+                          connect_debug: int = None, verbose=False):
         """
         Runs a single configuration using lightning_trainable.launcher.fit
         in a subprocess and waits for the result.
@@ -45,13 +46,17 @@ def run_configuration(self, config: List[Path | str | Tuple[str, object]], num_t
             arguments.append("--pycharm-debug")
             arguments.append(str(connect_debug))
 
-        all_config = {
-            "num_threads": num_threads,
-            **config
-        }
-        if len(all_config) > 0:
-            for key, value in all_config.items():
-                arguments.append(f'{key}={dump(value)}')
+        if num_threads is not None:
+            config = config + [("num_threads", num_threads)]
+        if len(config) > 0:
+            for value in config:
+                if isinstance(value, Path):
+                    arguments.append(str(value))
+                elif isinstance(value, tuple):
+                    key, value = value
+                    if isinstance(value, type):
+                        value = f"{value.__module__}.{value.__name__}"
+                    arguments.append(f'{key}={dump(value)}')
 
         out = None if verbose else subprocess.PIPE
         with Popen(['python', '-m', 'lightning_trainable.launcher.fit', *arguments],
@@ -63,19 +68,48 @@ def run_configuration(self, config: List[Path | str | Tuple[str, object]], num_t
             self.running_processes.remove(process)
         return RunResult(config=config, return_code=process.poll(), stdout=stdout, stderr=stderr)
 
-    def grid_spec_to_list(self, config_spec: Dict[str, list]):
+    def grid_spec_to_list(self, config_spec: Dict[str, list] | List[list | Tuple[str, list]]):
         """
-        Converts a grid of name=list[... values] to a list of dict configurations.
+        Converts a grid of specifications to a list of configurations.
+
+        Each specification can be a list of values to be passed
+        directly to the script or a tuple of a key and a list of values.
+
+        For example:
+        >>> grid_launcher = GridLauncher()
+        >>> grid_launcher.grid_spec_to_list([
+        >>>     ("model", ["tests.test_launcher.BasicTrainable"]),
+        >>>     ["test_launcher_config.yaml"],
+        >>>     ("num_threads", [1, 2, 4]),
+        >>> ])
         """
-        config_keys = list(config_spec.keys())
         configs = []
-        for config_values in product(*config_spec.values()):
-            config = dict(zip(config_keys, config_values))
+
+        fake_keys = set()
+
+        # Create fake keys for non-tuple entries
+        dict_args = []
+        for entry in config_spec:
+            if isinstance(entry, tuple):
+                dict_args.append(entry)
+            else:
+                fake_key = f"fake_key_{len(fake_keys)}"
+                fake_keys.add(fake_key)
+                dict_args.append((fake_key, entry))
+        dict_args = dict(dict_args)
+
+        # Create all possible combinations, removing fake keys
+        config_keys = list(dict_args.keys())
+        for config_values in product(*dict_args.values()):
+            config = [
+                value if key in fake_keys else (key, value)
+                for key, value in zip(config_keys, config_values)
+            ]
             configs.append(config)
         return configs
 
-    def start_runs(self, configs: List[List[Path, str]], num_parallel_runs=None,
-                   num_threads=None, connect_debug: int = None, verbose=False):
+    def start_runs(self, configs: List[List[Path | str]], num_parallel_runs=None,
+                   num_threads=1, connect_debug: int = None, verbose=False):
         """
         Starts a number of runs in parallel and returns the futures.
         """
@@ -93,8 +127,8 @@ def start_runs(self, configs: List[List[Path, str]], num_parallel_runs=None,
         return pool, futures
 
     def run_configs_and_wait(self,
-                             configs: List[List[Path, str]], num_parallel_runs=None,
-                             num_threads=None, connect_debug: int = None, verbose=False) -> List[RunResult]:
+                             configs: List[List[Path | str]], num_parallel_runs=None,
+                             num_threads=1, connect_debug: int = None, verbose=False) -> List[RunResult]:
         """
         Runs a list of configurations in parallel and waits for the results.
         """
@@ -132,7 +166,7 @@ def run_configs_and_wait(self,
 
         # Print results
         status_counts = status_count_counter(results)
-        print(f"Done running {sum([config.count for config in configs])} experiments: {status_counts}")
+        print(f"Done running {len(configs)} experiments: {status_counts}")
         if len(set(status_counts) - {0}) > 0:
             print(f"Total: {sum(value for key, value in status_counts.items() if key != 0)} FAILED!")
         else:
diff --git a/tests/test_launcher.py b/tests/test_launcher.py
index 071c7a3..ac60095 100644
--- a/tests/test_launcher.py
+++ b/tests/test_launcher.py
@@ -1,11 +1,16 @@
+import os
+from pathlib import Path
+
 import torch
 from lightning_trainable import Trainable, TrainableHParams
 from lightning_trainable.launcher.fit import main
 from torch.utils.data import TensorDataset
 
+from lightning_trainable.launcher.grid import GridLauncher
+
 
 class BasicTrainableHParams(TrainableHParams):
-    data_set_name: str
+    domain: list
 
 
 class BasicTrainable(Trainable):
@@ -13,8 +18,7 @@ def __init__(self, hparams: BasicTrainableHParams | dict):
         if not isinstance(hparams, BasicTrainableHParams):
             hparams = BasicTrainableHParams(**hparams)
 
-        assert hparams.data_set_name == "sine"
-        x = torch.linspace(-5, 5, 1000)[:, None]
+        x = torch.linspace(*hparams.domain, 1000)[:, None]
         y = torch.sin(x)[:, None]
         train_data = TensorDataset(x, y)
 
@@ -35,8 +39,21 @@ def compute_metrics(self, batch, batch_idx) -> dict:
 def test_fit_launcher():
     main([
         "model=tests.test_launcher.BasicTrainable",
-        "batch_size=128",
+        str(Path(__file__).parent / "test_launcher_config.yaml"),
         "max_epochs=1",
-        "data_set_name='sine'",
+        "domain=[-5, 5]",
         "accelerator='cpu'"
     ])
+
+
+def test_grid_launcher():
+    launcher = GridLauncher()
+    config_list = launcher.grid_spec_to_list([
+        ("model", ["tests.test_launcher.BasicTrainable"]),
+        ([Path(__file__).parent / "test_launcher_config.yaml"]),
+        ("max_epochs", [1]),
+        ("domain", [[-5, 5], [-3, 3]]),
+        ("accelerator", ['cpu'])
+    ])
+    results = launcher.run_configs_and_wait(config_list)
+    print(results[0].stderr.decode())
diff --git a/tests/test_launcher_config.yaml b/tests/test_launcher_config.yaml
new file mode 100644
index 0000000..1750891
--- /dev/null
+++ b/tests/test_launcher_config.yaml
@@ -0,0 +1 @@
+batch_size: 128

From 837e5545b47a34b5aa1dec9107de6c83615f66a4 Mon Sep 17 00:00:00 2001
From: Felix Draxler
Date: Mon, 27 Mar 2023 09:08:54 +0200
Subject: [PATCH 5/9] Launcher: --name formatting using str.format(hparams)

---
 src/lightning_trainable/launcher/fit.py  | 14 ++++++++++----
 src/lightning_trainable/launcher/grid.py | 14 +++++++++-----
 tests/test_launcher.py                   |  9 +++++----
 3 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/src/lightning_trainable/launcher/fit.py b/src/lightning_trainable/launcher/fit.py
index 4f7b840..ca7584e 100644
--- a/src/lightning_trainable/launcher/fit.py
+++ b/src/lightning_trainable/launcher/fit.py
@@ -34,15 +34,16 @@ def main(args=None):
         torch.set_num_threads(num_threads)
 
     # Load the model
-    module_name, model_name = hparams.pop("model").rsplit(".", 1)
-    module = import_module(module_name)
-    model = getattr(module, model_name)(hparams=hparams)
+    module_name, model_name = hparams["model"].rsplit(".", 1)
 
     # Log path
     if args.name is not None:
         logger_kwargs = dict(
             save_dir="lightning_logs",
-            name=args.name
+            name=args.name.format(
+                model_name=model_name,
+                **hparams
+            )
         )
     elif args.log_dir is not None:
         save_path = Path(args.save_dir)
@@ -54,6 +55,11 @@
     else:
         logger_kwargs = dict()
 
+    # No "model" hparam
+    del hparams["model"]
+    module = import_module(module_name)
+    model = getattr(module, model_name)(hparams=hparams)
+
     # Fit the model
     return model.fit(logger_kwargs=logger_kwargs)
 
diff --git a/src/lightning_trainable/launcher/grid.py b/src/lightning_trainable/launcher/grid.py
index 81ebf0e..19b4263 100644
--- a/src/lightning_trainable/launcher/grid.py
+++ b/src/lightning_trainable/launcher/grid.py
@@ -36,7 +36,7 @@ def send_message(self, message):
             send_telegram_message(message, **self.telegram_info)
 
     def run_configuration(self, config: List[Path | str | Tuple[str, object]], num_threads: int = None,
-                          connect_debug: int = None, verbose=False):
+                          connect_debug: int = None, verbose=False, cli_args=None):
         """
         Runs a single configuration using lightning_trainable.launcher.fit
         in a subprocess and waits for the result.
@@ -45,6 +45,8 @@ def run_configuration(self, config: List[Path | str | Tuple[str, object]], num_t
         if connect_debug is not None:
             arguments.append("--pycharm-debug")
             arguments.append(str(connect_debug))
+        if cli_args is not None:
+            arguments.extend(cli_args)
 
         if num_threads is not None:
             config = config + [("num_threads", num_threads)]
@@ -109,7 +111,7 @@ def grid_spec_to_list(self, config_spec: Dict[str, list] | List[list | Tuple[str
         return configs
 
     def start_runs(self, configs: List[List[Path | str]], num_parallel_runs=None,
-                   num_threads=1, connect_debug: int = None, verbose=False):
+                   num_threads=1, connect_debug: int = None, verbose=False, cli_args=None):
         """
         Starts a number of runs in parallel and returns the futures.
         """
@@ -122,20 +124,22 @@ def start_runs(self, configs: List[List[Path | str]], num_parallel_runs=None,
             futures.append(pool.submit(
                 self.run_configuration,
                 config=config, num_threads=num_threads,
-                connect_debug=connect_debug, verbose=verbose
+                connect_debug=connect_debug, verbose=verbose,
+                cli_args=cli_args
             ))
         return pool, futures
 
     def run_configs_and_wait(self,
-                             configs: List[List[Path | str]], num_parallel_runs=None,
-                             num_threads=1, connect_debug: int = None, verbose=False) -> List[RunResult]:
+                             configs: List[List[Path | str]], num_parallel_runs=None,
+                             num_threads=1, connect_debug: int = None, verbose=False,
+                             cli_args=None) -> List[RunResult]:
         """
         Runs a list of configurations in parallel and waits for the results.
""" pool, futures = self.start_runs( configs, num_parallel_runs=num_parallel_runs, num_threads=num_threads, - connect_debug=connect_debug, verbose=verbose + connect_debug=connect_debug, verbose=verbose, cli_args=cli_args ) interrupted_count = 0 while True: diff --git a/tests/test_launcher.py b/tests/test_launcher.py index ac60095..4a9b52d 100644 --- a/tests/test_launcher.py +++ b/tests/test_launcher.py @@ -6,7 +6,7 @@ from lightning_trainable.launcher.fit import main from torch.utils.data import TensorDataset -from lightning_trainable.launcher.grid import GridLauncher +from lightning_trainable.launcher.grid import GridLauncher, status_count_counter class BasicTrainableHParams(TrainableHParams): @@ -42,7 +42,8 @@ def test_fit_launcher(): str(Path(__file__).parent / "test_launcher_config.yaml"), "max_epochs=1", "domain=[-5, 5]", - "accelerator='cpu'" + "accelerator='cpu'", + "--name", "{model_name};{max_epochs}" ]) @@ -55,5 +56,5 @@ def test_grid_launcher(): ("domain", [[-5, 5], [-3, 3]]), ("accelerator", ['cpu']) ]) - results = launcher.run_configs_and_wait(config_list) - print(results[0].stderr.decode()) + results = launcher.run_configs_and_wait(config_list, cli_args=["--name", "{model_name};{max_epochs}"]) + assert status_count_counter(results) == {0: 2} From d73585582ffecabc03ba910a43cc7a113753a979 Mon Sep 17 00:00:00 2001 From: Felix Draxler Date: Mon, 27 Mar 2023 09:13:39 +0200 Subject: [PATCH 6/9] Fix multiline code snipped in docstring --- src/lightning_trainable/launcher/grid.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning_trainable/launcher/grid.py b/src/lightning_trainable/launcher/grid.py index 19b4263..cdfff17 100644 --- a/src/lightning_trainable/launcher/grid.py +++ b/src/lightning_trainable/launcher/grid.py @@ -79,10 +79,10 @@ def grid_spec_to_list(self, config_spec: Dict[str, list] | List[list | Tuple[str For example: >>> grid_launcher = GridLauncher() - >>> grid_launcher.grid_spec_to_list([ - >>> ("model", ["tests.test_launcher.BasicTrainable"]), - >>> ["test_launcher_config.yaml"], - >>> ("num_threads", [1, 2, 4]), + >>> grid_launcher.grid_spec_to_list([ \ + >>> ("model", ["tests.test_launcher.BasicTrainable"]), \ + >>> ["test_launcher_config.yaml"], \ + >>> ("num_threads", [1, 2, 4]), \ >>> ]) """ configs = [] From ae9063af0f818dd041b41bbdff87072ed7f19808 Mon Sep 17 00:00:00 2001 From: Felix Draxler Date: Mon, 27 Mar 2023 09:17:20 +0200 Subject: [PATCH 7/9] Fix multiline code snipped in docstring (second attempt) --- src/lightning_trainable/launcher/grid.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning_trainable/launcher/grid.py b/src/lightning_trainable/launcher/grid.py index cdfff17..61a55c7 100644 --- a/src/lightning_trainable/launcher/grid.py +++ b/src/lightning_trainable/launcher/grid.py @@ -80,10 +80,10 @@ def grid_spec_to_list(self, config_spec: Dict[str, list] | List[list | Tuple[str For example: >>> grid_launcher = GridLauncher() >>> grid_launcher.grid_spec_to_list([ \ - >>> ("model", ["tests.test_launcher.BasicTrainable"]), \ - >>> ["test_launcher_config.yaml"], \ - >>> ("num_threads", [1, 2, 4]), \ - >>> ]) + ("model", ["tests.test_launcher.BasicTrainable"]), \ + ["test_launcher_config.yaml"], \ + ("num_threads", [1, 2, 4]), \ + ]) """ configs = [] From c440424a58e797ab64723cb03c02492741ba7831 Mon Sep 17 00:00:00 2001 From: Felix Draxler Date: Mon, 27 Mar 2023 09:20:45 +0200 Subject: [PATCH 8/9] Remove example in doc for now --- src/lightning_trainable/launcher/grid.py | 12 
++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lightning_trainable/launcher/grid.py b/src/lightning_trainable/launcher/grid.py index 61a55c7..965302c 100644 --- a/src/lightning_trainable/launcher/grid.py +++ b/src/lightning_trainable/launcher/grid.py @@ -78,12 +78,12 @@ def grid_spec_to_list(self, config_spec: Dict[str, list] | List[list | Tuple[str directly to the script or a tuple of a key and a list of values. For example: - >>> grid_launcher = GridLauncher() - >>> grid_launcher.grid_spec_to_list([ \ - ("model", ["tests.test_launcher.BasicTrainable"]), \ - ["test_launcher_config.yaml"], \ - ("num_threads", [1, 2, 4]), \ - ]) + grid_launcher = GridLauncher() + grid_launcher.grid_spec_to_list([ + ("model", ["tests.test_launcher.BasicTrainable"]), + ["test_launcher_config.yaml"], + ("num_threads", [1, 2, 4]), + ]) """ configs = [] From 27c03621642e679ee1caf52652bedb880914b2ee Mon Sep 17 00:00:00 2001 From: Felix Draxler Date: Mon, 27 Mar 2023 09:27:46 +0200 Subject: [PATCH 9/9] Filter os.setpgrp call to non-Windows --- src/lightning_trainable/launcher/grid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning_trainable/launcher/grid.py b/src/lightning_trainable/launcher/grid.py index 965302c..03334a5 100644 --- a/src/lightning_trainable/launcher/grid.py +++ b/src/lightning_trainable/launcher/grid.py @@ -1,4 +1,5 @@ import os +import platform import subprocess from dataclasses import dataclass from subprocess import Popen @@ -64,7 +65,7 @@ def run_configuration(self, config: List[Path | str | Tuple[str, object]], num_t with Popen(['python', '-m', 'lightning_trainable.launcher.fit', *arguments], stdout=out, stderr=out, # Signals to controller are not passed to runner - preexec_fn=os.setpgrp) as process: + preexec_fn=None if platform.system() == "Windows" else os.setpgrp) as process: self.running_processes.append(process) stdout, stderr = process.communicate() self.running_processes.remove(process)