Merge pull request #4 from LarsKue/launcher
Collection of launcher scripts
Showing 6 changed files with 392 additions and 0 deletions.
Empty file.
@@ -0,0 +1,68 @@
from argparse import ArgumentParser
from importlib import import_module
from pathlib import Path

from lightning_trainable.launcher.utils import parse_config_dict

import torch


def main(args=None):
    parser = ArgumentParser()
    parser.add_argument("--pycharm-debug", default=False, type=int,
                        help="Port of PyCharm remote debugger to connect to.")
    log_dir_group = parser.add_mutually_exclusive_group()
    log_dir_group.add_argument("--name", type=str,
                               help="Name of experiment. Experiment data will be stored in "
                                    "lightning_logs/`name`/version_X")
    log_dir_group.add_argument("--log-dir", type=str,
help="Experiment data will be stored `log_dir`'") | ||
parser.add_argument("config_args", type=str, nargs="*", | ||
help="") | ||
    args = parser.parse_args(args)

    if args.pycharm_debug:
        import pydevd_pycharm

        pydevd_pycharm.settrace('localhost', port=args.pycharm_debug, stdoutToServer=True, stderrToServer=True)

    hparams = parse_config_dict(args.config_args)

    # Set number of threads (potentially move into trainable, but it's a global property)
    num_threads = hparams.pop("num_threads", None)
    if num_threads is not None:
        torch.set_num_threads(num_threads)

    # Load the model
    module_name, model_name = hparams["model"].rsplit(".", 1)

    # Log path
    if args.name is not None:
        logger_kwargs = dict(
            save_dir="lightning_logs",
            name=args.name.format(
                model_name=model_name,
                **hparams
            )
        )
    elif args.log_dir is not None:
        save_path = Path(args.log_dir)
        logger_kwargs = dict(
            version=save_path.name,
            experiment_name=save_path.parent.name,
            save_dir=save_path.parent.parent
        )
    else:
        logger_kwargs = dict()

    # No "model" hparam
    del hparams["model"]
    module = import_module(module_name)
    model = getattr(module, model_name)(hparams=hparams)

    # Fit the model
    return model.fit(logger_kwargs=logger_kwargs)


if __name__ == '__main__':
    main()
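For orientation, a minimal sketch of how this entry point can be driven programmatically, in the same way the tests further below do. The model path `my_package.MyTrainable` and the file `my_config.yaml` are placeholders for illustration, not files added by this commit:

from lightning_trainable.launcher.fit import main

# Config arguments mix config files and key=value pairs; values are parsed as YAML.
main([
    "model=my_package.MyTrainable",         # hypothetical import path of a Trainable subclass
    "my_config.yaml",                       # hypothetical config file providing e.g. batch_size
    "max_epochs=10",
    "--name", "{model_name};{max_epochs}",  # logs go to lightning_logs/<formatted name>/version_X
])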
@@ -0,0 +1,217 @@
import os
import platform
import subprocess
from dataclasses import dataclass
from subprocess import Popen
from collections import namedtuple, Counter
from concurrent.futures import as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from itertools import product
from math import log10
from pathlib import Path
from typing import Dict, List, Tuple
from datetime import timedelta

from lightning_trainable.launcher.utils import send_telegram_message
from tqdm import tqdm
from yaml import dump

ConfigWithCount = namedtuple("ConfigWithCount", ["config", "count"])


@dataclass
class RunResult:
    config: Dict
    return_code: int | str
    stdout: bytes
    stderr: bytes


class GridLauncher:
    def __init__(self, telegram_info: Dict = None):
        self.running_processes: List[Popen] = []
        self.telegram_info = telegram_info

    def send_message(self, message):
        if self.telegram_info is not None:
            send_telegram_message(message, **self.telegram_info)

    def run_configuration(self, config: List[Path | str | Tuple[str, object]], num_threads: int = None,
                          connect_debug: int = None, verbose=False, cli_args=None):
        """
        Runs a single configuration using lightning_trainable.launcher.fit
        in a subprocess and waits for the result.
        """
        arguments = []
        if connect_debug is not None:
            arguments.append("--pycharm-debug")
            arguments.append(str(connect_debug))
        if cli_args is not None:
            arguments.extend(cli_args)

        if num_threads is not None:
            config = config + [("num_threads", num_threads)]
        if len(config) > 0:
            for value in config:
                if isinstance(value, Path):
                    arguments.append(str(value))
                elif isinstance(value, tuple):
                    key, value = value
                    if isinstance(value, type):
                        value = f"{value.__module__}.{value.__name__}"
                    arguments.append(f'{key}={dump(value)}')

        out = None if verbose else subprocess.PIPE
        with Popen(['python', '-m', 'lightning_trainable.launcher.fit', *arguments],
                   stdout=out, stderr=out,
                   # Signals to controller are not passed to runner
                   preexec_fn=None if platform.system() == "Windows" else os.setpgrp) as process:
            self.running_processes.append(process)
            stdout, stderr = process.communicate()
            self.running_processes.remove(process)
            return RunResult(config=config, return_code=process.poll(), stdout=stdout, stderr=stderr)

    def grid_spec_to_list(self, config_spec: Dict[str, list] | List[list | Tuple[str, list]]):
        """
        Converts a grid of specifications to a list of configurations.
        Each specification can be a list of values to be passed
        directly to the script or a tuple of a key and a list of values.
        For example:
            grid_launcher = GridLauncher()
            grid_launcher.grid_spec_to_list([
                ("model", ["tests.test_launcher.BasicTrainable"]),
                ["test_launcher_config.yaml"],
                ("num_threads", [1, 2, 4]),
            ])
        """
        configs = []

        fake_keys = set()

        # Create fake keys for non-tuple entries
        dict_args = []
        for entry in config_spec:
            if isinstance(entry, tuple):
                dict_args.append(entry)
            else:
                fake_key = f"fake_key_{len(fake_keys)}"
                fake_keys.add(fake_key)
                dict_args.append((fake_key, entry))
        dict_args = dict(dict_args)

        # Create all possible combinations, removing fake keys
        config_keys = list(dict_args.keys())
        for config_values in product(*dict_args.values()):
            config = [
                value if key in fake_keys else (key, value)
                for key, value in zip(config_keys, config_values)
            ]
            configs.append(config)
        return configs

    def start_runs(self, configs: List[List[Path | str]], num_parallel_runs=None,
                   num_threads=1, connect_debug: int = None, verbose=False, cli_args=None):
        """
        Starts a number of runs in parallel and returns the executor pool and the list of futures.
""" | ||
if num_parallel_runs is None: | ||
num_parallel_runs = max(1, os.cpu_count() // num_threads - 1) | ||
|
||
pool = ThreadPoolExecutor(num_parallel_runs) | ||
futures = [] | ||
for config in configs: | ||
futures.append(pool.submit( | ||
self.run_configuration, | ||
config=config, num_threads=num_threads, | ||
connect_debug=connect_debug, verbose=verbose, | ||
cli_args=cli_args | ||
)) | ||
return pool, futures | ||
|
||
def run_configs_and_wait(self, | ||
configs: List[List[Path | str]], num_parallel_runs=None, | ||
num_threads=1, connect_debug: int = None, verbose=False, | ||
cli_args=None) -> List[RunResult]: | ||
""" | ||
Runs a list of configurations in parallel and waits for the results. | ||
""" | ||
pool, futures = self.start_runs( | ||
configs, | ||
num_parallel_runs=num_parallel_runs, num_threads=num_threads, | ||
connect_debug=connect_debug, verbose=verbose, cli_args=cli_args | ||
) | ||
interrupted_count = 0 | ||
while True: | ||
try: | ||
results = self.fetch_results(futures) | ||
break | ||
except KeyboardInterrupt: | ||
interrupted_count += 1 | ||
if interrupted_count == 1: | ||
# Cancel future runs | ||
pool.shutdown(wait=False, cancel_futures=True) | ||
# Pool shutdown does not mark futures as_completed | ||
# https://github.com/python/cpython/issues/87893 | ||
for f in tqdm(futures, desc="Cancelling future runs"): | ||
if f.cancelled(): | ||
f.set_running_or_notify_cancel() | ||
print("Stopped all pending experiments.") | ||
print("Hit Ctrl-C again to cancel running experiments.") | ||
elif interrupted_count == 2: | ||
# Cancel current runs | ||
for process in tqdm(self.running_processes, desc="Killing processes"): | ||
process.kill() | ||
print("Stopped all running experiments.") | ||
if interrupted_count > 2: | ||
raise KeyboardInterrupt | ||
# Wait for remaining processes | ||
pool.shutdown(wait=True) | ||
|
||
# Print results | ||
status_counts = status_count_counter(results) | ||
print(f"Done running {len(configs)} experiments: {status_counts}") | ||
if len(set(status_counts) - {0}) > 0: | ||
print(f"Total: {sum(value for key, value in status_counts.items() if key != 0)} FAILED!") | ||
else: | ||
print("All succeeded :D") | ||
self.send_message(f"Launcher done: {status_counts}") | ||
return results | ||
|
||
def fetch_results(self, futures, timeout=None): | ||
""" | ||
Fetches the results from a list of futures. | ||
""" | ||
last_elapsed = 60 | ||
results = [] | ||
with tqdm(as_completed(futures, timeout=timeout), total=len(futures), smoothing=0) as pbar: | ||
for future in pbar: | ||
if future.cancelled(): | ||
result = RunResult(None, "cancelled", None, None) | ||
else: | ||
result: RunResult = future.result() | ||
results.append(result) | ||
|
||
status_counts = status_count_counter(results) | ||
result_code = result.return_code | ||
|
||
elapsed = pbar.format_dict["elapsed"] | ||
elapsed_delta = timedelta(seconds=elapsed) | ||
if result_code != 0 and status_counts[result_code] == 10 ** int( | ||
log10(status_counts[result_code])): | ||
self.send_message( | ||
f"Code {result_code}: {status_counts[result_code]} " | ||
f"failed after {elapsed_delta}." | ||
) | ||
elif result_code == 0 and elapsed > last_elapsed * 2: | ||
self.send_message( | ||
f"{status_counts[result_code]} succeeded after {elapsed_delta}." | ||
) | ||
last_elapsed = elapsed | ||
pbar.set_description(str(status_counts)) | ||
return results | ||
|
||
|
||
def status_count_counter(results: List[RunResult]) -> Counter: | ||
return Counter(result.return_code for result in results) |
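A rough usage sketch for the grid launcher, mirroring the test added further below. The import path `my_package.MyTrainable` and the file `my_config.yaml` are placeholders; note that file entries are wrapped in Path, as in the test:

from pathlib import Path
from lightning_trainable.launcher.grid import GridLauncher, status_count_counter

launcher = GridLauncher()  # pass telegram_info=dict(token=..., chats=[...]) to receive progress messages
configs = launcher.grid_spec_to_list([
    ("model", ["my_package.MyTrainable"]),  # hypothetical Trainable import path
    [Path("my_config.yaml")],               # hypothetical config file, forwarded to the fit launcher
    ("max_epochs", [1, 10]),                # grid axis with two values -> two subprocess runs
])
results = launcher.run_configs_and_wait(configs, num_threads=1)
print(status_count_counter(results))        # e.g. Counter({0: 2}) if both runs exit cleanly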
@@ -0,0 +1,46 @@
from pathlib import Path
from typing import List
from urllib.parse import urlencode
from urllib.request import Request, urlopen

from yaml import safe_load


def parse_config_dict(config_spec: List[Path | str]):
    hparams = {}
    for arg in config_spec:
        if isinstance(arg, Path) or (
            any(
                arg.endswith(suffix) for suffix in [".yaml", ".yml", ".json"]
            ) and "=" not in arg
        ):
            # Read multiple entries from .yaml file
            with open(arg, "r") as file:
                new_hparams = safe_load(file)
        else:
            # Read single entry from command line
            key, value = arg.split("=")
            new_hparams = {key: safe_load(value)}

        # Merge in new parameters
        for key, value in new_hparams.items():
            hparam_level = hparams
            key_path = key.split(".")
            for key_entry in key_path[:-1]:
                hparam_level = hparam_level[key_entry]
            hparam_level[key_path[-1]] = value
    return hparams


def send_telegram_message(message: str, token: str, chats: List[int]):
    try:
        url = f"https://api.telegram.org/bot{token}/sendMessage"
        for chat_id in chats:
            params = {
                "chat_id": chat_id,
                "text": message
            }
            request = Request(url, urlencode(params).encode())
            urlopen(request).read().decode()
    except FileNotFoundError:
        pass
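To illustrate how parse_config_dict merges its arguments (a sketch; the `optimizer` and `batch_size` keys are made-up example hparams, not keys the library defines): values after `=` are parsed as YAML, later entries override earlier ones, and dotted keys descend into nested dicts created by earlier entries:

from lightning_trainable.launcher.utils import parse_config_dict

hparams = parse_config_dict([
    "optimizer={name: adam, lr: 0.001}",  # value is parsed as a YAML mapping
    "optimizer.lr=0.0005",                # dotted key walks into the existing nested dict
    "batch_size=128",
])
print(hparams)  # {'optimizer': {'name': 'adam', 'lr': 0.0005}, 'batch_size': 128}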
@@ -0,0 +1,60 @@
import os
from pathlib import Path

import torch
from lightning_trainable import Trainable, TrainableHParams
from lightning_trainable.launcher.fit import main
from torch.utils.data import TensorDataset

from lightning_trainable.launcher.grid import GridLauncher, status_count_counter


class BasicTrainableHParams(TrainableHParams):
    domain: list


class BasicTrainable(Trainable):
    def __init__(self, hparams: BasicTrainableHParams | dict):
        if not isinstance(hparams, BasicTrainableHParams):
            hparams = BasicTrainableHParams(**hparams)

        x = torch.linspace(*hparams.domain, 1000)[:, None]
        y = torch.sin(x)  # x is already (1000, 1), so no extra dimension is needed for y
        train_data = TensorDataset(x, y)

        super().__init__(hparams, train_data=train_data)

        self.model = torch.nn.Sequential(
            torch.nn.Linear(1, 16), torch.nn.ReLU(),
            torch.nn.Linear(16, 1)
        )

    def compute_metrics(self, batch, batch_idx) -> dict:
        x, y = batch
        return {
            "loss": ((self.model(x) - y) ** 2).mean()
        }


def test_fit_launcher():
    main([
        "model=tests.test_launcher.BasicTrainable",
        str(Path(__file__).parent / "test_launcher_config.yaml"),
        "max_epochs=1",
        "domain=[-5, 5]",
        "accelerator='cpu'",
        "--name", "{model_name};{max_epochs}"
    ])


def test_grid_launcher():
    launcher = GridLauncher()
    config_list = launcher.grid_spec_to_list([
        ("model", ["tests.test_launcher.BasicTrainable"]),
        ([Path(__file__).parent / "test_launcher_config.yaml"]),
        ("max_epochs", [1]),
        ("domain", [[-5, 5], [-3, 3]]),
        ("accelerator", ['cpu'])
    ])
    results = launcher.run_configs_and_wait(config_list, cli_args=["--name", "{model_name};{max_epochs}"])
    assert status_count_counter(results) == {0: 2}
@@ -0,0 +1 @@
batch_size: 128