Skip to content

Commit

Permalink
Merge pull request #4 from LarsKue/launcher
Browse files Browse the repository at this point in the history
Collection of launcher scripts
  • Loading branch information
fdraxler authored Mar 27, 2023
2 parents ad125e9 + 27c0362 commit a3c6ed7
Show file tree
Hide file tree
Showing 6 changed files with 392 additions and 0 deletions.
Empty file.
68 changes: 68 additions & 0 deletions src/lightning_trainable/launcher/fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from argparse import ArgumentParser
from importlib import import_module
from pathlib import Path

from lightning_trainable.launcher.utils import parse_config_dict

import torch


def main(args=None):
    """
    Fit a single model described by command line arguments.

    Positional arguments are config files (.yaml, .yml, .json) or
    `key=value` pairs; they are merged into one hparams dict by
    `parse_config_dict`. Two entries of that dict are consumed here:

    - `model` (required): dotted path `package.module.ClassName` of the
      class to instantiate. It is removed from the hparams before they
      are passed to the class as `hparams=...`.
    - `num_threads` (optional): passed to `torch.set_num_threads` and
      removed from the hparams.

    :param args: List of command line arguments. If None, argparse falls
        back to `sys.argv`.
    :return: Whatever `model.fit(logger_kwargs=...)` returns.
    :raises KeyError: If no `model` entry was specified.
    :raises ValueError: If `model` is not a dotted path.
    """
    parser = ArgumentParser()
    parser.add_argument("--pycharm-debug", default=False, type=int,
                        help="Port of PyCharm remote debugger to connect to.")
    log_dir_group = parser.add_mutually_exclusive_group()
    log_dir_group.add_argument("--name", type=str,
                               help="Name of experiment. Experiment data will be stored in "
                                    "lightning_logs/`name`/version_X")
    log_dir_group.add_argument("--log-dir", type=str,
                               help="Experiment data will be stored `log_dir`'")
    parser.add_argument("config_args", type=str, nargs="*",
                        help="")
    args = parser.parse_args(args)

    if args.pycharm_debug:
        # Imported lazily so that the debugger package is only required when requested
        import pydevd_pycharm

        pydevd_pycharm.settrace('localhost', port=args.pycharm_debug, stdoutToServer=True, stderrToServer=True)

    hparams = parse_config_dict(args.config_args)

    # Set number of threads (potentially move into trainable, but it's a global property)
    num_threads = hparams.pop("num_threads", None)
    if num_threads is not None:
        torch.set_num_threads(num_threads)

    # Load the model
    if "model" not in hparams:
        raise KeyError("No `model` specified. Pass e.g. `model=package.module.ClassName`.")
    model_path = hparams["model"]
    if not isinstance(model_path, str) or "." not in model_path:
        raise ValueError(
            f"`model` must be a dotted path like `package.module.ClassName`, but got {model_path!r}."
        )
    module_name, model_name = model_path.rsplit(".", 1)

    # Log path
    if args.name is not None:
        logger_kwargs = dict(
            save_dir="lightning_logs",
            # `hparams` still contains "model" here, so `{model}` may be used in the template
            name=args.name.format(
                model_name=model_name,
                **hparams
            )
        )
    elif args.log_dir is not None:
        # argparse stores `--log-dir` as `args.log_dir`; there is no `save_dir` attribute
        save_path = Path(args.log_dir)
        logger_kwargs = dict(
            version=save_path.name,
            # NOTE(review): the `--name` branch passes `name=` instead;
            # confirm that the logger accepts `experiment_name`.
            experiment_name=save_path.parent.name,
            save_dir=save_path.parent.parent
        )
    else:
        logger_kwargs = dict()

    # No "model" hparam
    del hparams["model"]
    module = import_module(module_name)
    model = getattr(module, model_name)(hparams=hparams)

    # Fit the model
    return model.fit(logger_kwargs=logger_kwargs)


if __name__ == '__main__':
    main()
217 changes: 217 additions & 0 deletions src/lightning_trainable/launcher/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import os
import platform
import subprocess
from dataclasses import dataclass
from subprocess import Popen
from collections import namedtuple, Counter
from concurrent.futures import as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from itertools import product
from math import log10
from pathlib import Path
from typing import Dict, List, Tuple
from datetime import timedelta

from lightning_trainable.launcher.utils import send_telegram_message
from tqdm import tqdm
from yaml import dump

# Lightweight record pairing a configuration (`config`) with a number (`count`).
# NOTE(review): not referenced anywhere in the visible part of this module.
ConfigWithCount = namedtuple("ConfigWithCount", ["config", "count"])


@dataclass
class RunResult:
config: Dict
return_code: int | str
stdout: bytes
stderr: bytes


class GridLauncher:
def __init__(self, telegram_info: Dict = None):
self.running_processes: List[Popen] = []
self.telegram_info = telegram_info

def send_message(self, message):
if self.telegram_info is not None:
send_telegram_message(message, **self.telegram_info)

def run_configuration(self, config: List[Path | str | Tuple[str, object]], num_threads: int = None,
connect_debug: int = None, verbose=False, cli_args=None):
"""
Runs a single configuration using lightning_trainable.launcher.fit
in a subprocess and waits for the result.
"""
arguments = []
if connect_debug is not None:
arguments.append("--pycharm-debug")
arguments.append(str(connect_debug))
if cli_args is not None:
arguments.extend(cli_args)

if num_threads is not None:
config = config + [("num_threads", num_threads)]
if len(config) > 0:
for value in config:
if isinstance(value, Path):
arguments.append(str(value))
elif isinstance(value, tuple):
key, value = value
if isinstance(value, type):
value = f"{value.__module__}.{value.__name__}"
arguments.append(f'{key}={dump(value)}')

out = None if verbose else subprocess.PIPE
with Popen(['python', '-m', 'lightning_trainable.launcher.fit', *arguments],
stdout=out, stderr=out,
# Signals to controller are not passed to runner
preexec_fn=None if platform.system() == "Windows" else os.setpgrp) as process:
self.running_processes.append(process)
stdout, stderr = process.communicate()
self.running_processes.remove(process)
return RunResult(config=config, return_code=process.poll(), stdout=stdout, stderr=stderr)

def grid_spec_to_list(self, config_spec: Dict[str, list] | List[list | Tuple[str, list]]):
"""
Converts a grid of specifications to a list of configurations.
Each specification can be a list of values to be passed
directly to the script or a tuple of a key and a list of values.
For example:
grid_launcher = GridLauncher()
grid_launcher.grid_spec_to_list([
("model", ["tests.test_launcher.BasicTrainable"]),
["test_launcher_config.yaml"],
("num_threads", [1, 2, 4]),
])
"""
configs = []

fake_keys = set()

# Create fake keys for non-tuple entries
dict_args = []
for entry in config_spec:
if isinstance(entry, tuple):
dict_args.append(entry)
else:
fake_key = f"fake_key_{len(fake_keys)}"
fake_keys.add(fake_key)
dict_args.append((fake_key, entry))
dict_args = dict(dict_args)

# Create all possible combinations, removing fake keys
config_keys = list(dict_args.keys())
for config_values in product(*dict_args.values()):
config = [
value if key in fake_keys else (key, value)
for key, value in zip(config_keys, config_values)
]
configs.append(config)
return configs

def start_runs(self, configs: List[List[Path | str]], num_parallel_runs=None,
num_threads=1, connect_debug: int = None, verbose=False, cli_args=None):
"""
Starts a number of runs in parallel and returns the futures.
"""
if num_parallel_runs is None:
num_parallel_runs = max(1, os.cpu_count() // num_threads - 1)

pool = ThreadPoolExecutor(num_parallel_runs)
futures = []
for config in configs:
futures.append(pool.submit(
self.run_configuration,
config=config, num_threads=num_threads,
connect_debug=connect_debug, verbose=verbose,
cli_args=cli_args
))
return pool, futures

def run_configs_and_wait(self,
configs: List[List[Path | str]], num_parallel_runs=None,
num_threads=1, connect_debug: int = None, verbose=False,
cli_args=None) -> List[RunResult]:
"""
Runs a list of configurations in parallel and waits for the results.
"""
pool, futures = self.start_runs(
configs,
num_parallel_runs=num_parallel_runs, num_threads=num_threads,
connect_debug=connect_debug, verbose=verbose, cli_args=cli_args
)
interrupted_count = 0
while True:
try:
results = self.fetch_results(futures)
break
except KeyboardInterrupt:
interrupted_count += 1
if interrupted_count == 1:
# Cancel future runs
pool.shutdown(wait=False, cancel_futures=True)
# Pool shutdown does not mark futures as_completed
# https://github.com/python/cpython/issues/87893
for f in tqdm(futures, desc="Cancelling future runs"):
if f.cancelled():
f.set_running_or_notify_cancel()
print("Stopped all pending experiments.")
print("Hit Ctrl-C again to cancel running experiments.")
elif interrupted_count == 2:
# Cancel current runs
for process in tqdm(self.running_processes, desc="Killing processes"):
process.kill()
print("Stopped all running experiments.")
if interrupted_count > 2:
raise KeyboardInterrupt
# Wait for remaining processes
pool.shutdown(wait=True)

# Print results
status_counts = status_count_counter(results)
print(f"Done running {len(configs)} experiments: {status_counts}")
if len(set(status_counts) - {0}) > 0:
print(f"Total: {sum(value for key, value in status_counts.items() if key != 0)} FAILED!")
else:
print("All succeeded :D")
self.send_message(f"Launcher done: {status_counts}")
return results

def fetch_results(self, futures, timeout=None):
"""
Fetches the results from a list of futures.
"""
last_elapsed = 60
results = []
with tqdm(as_completed(futures, timeout=timeout), total=len(futures), smoothing=0) as pbar:
for future in pbar:
if future.cancelled():
result = RunResult(None, "cancelled", None, None)
else:
result: RunResult = future.result()
results.append(result)

status_counts = status_count_counter(results)
result_code = result.return_code

elapsed = pbar.format_dict["elapsed"]
elapsed_delta = timedelta(seconds=elapsed)
if result_code != 0 and status_counts[result_code] == 10 ** int(
log10(status_counts[result_code])):
self.send_message(
f"Code {result_code}: {status_counts[result_code]} "
f"failed after {elapsed_delta}."
)
elif result_code == 0 and elapsed > last_elapsed * 2:
self.send_message(
f"{status_counts[result_code]} succeeded after {elapsed_delta}."
)
last_elapsed = elapsed
pbar.set_description(str(status_counts))
return results


def status_count_counter(results: List[RunResult]) -> Counter:
    """
    Tallies how often each return code occurs in `results`.

    :return: Counter mapping return code to number of occurrences.
    """
    tally = Counter()
    for entry in results:
        tally[entry.return_code] += 1
    return tally
46 changes: 46 additions & 0 deletions src/lightning_trainable/launcher/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from pathlib import Path
from typing import List
from urllib.parse import urlencode
from urllib.request import Request, urlopen

from yaml import safe_load


def parse_config_dict(config_spec: List[Path | str]):
hparams = {}
for arg in config_spec:
if isinstance(arg, Path) or (
any(
arg.endswith(suffix) for suffix in [".yaml", ".yml", ".json"]
) and "=" not in arg
):
# Read multiple entries from .yaml file
with open(arg, "r") as file:
new_hparams = safe_load(file)
else:
# Read single entry from command line
key, value = arg.split("=")
new_hparams = {key: safe_load(value)}

# Merge in new parameters
for key, value in new_hparams.items():
hparam_level = hparams
key_path = key.split(".")
for key_entry in key_path[:-1]:
hparam_level = hparam_level[key_entry]
hparam_level[key_path[-1]] = value
return hparams


def send_telegram_message(message: str, token: str, chats: List[int], timeout: float = 10.0):
    """
    Sends `message` to each chat via the Telegram bot API.

    Sending is best effort: if a request fails (connection problems,
    HTTP errors, timeouts), a warning is emitted and the remaining chats
    are still tried, so that a notification problem never aborts the caller.

    :param message: Text of the message.
    :param token: Token of the Telegram bot.
    :param chats: Ids of the chats to send the message to.
    :param timeout: Seconds to wait for each request.
    """
    # Local import: only needed for reporting failed requests
    import warnings

    url = f"https://api.telegram.org/bot{token}/sendMessage"
    for chat_id in chats:
        params = {
            "chat_id": chat_id,
            "text": message
        }
        request = Request(url, urlencode(params).encode())
        try:
            # The context manager closes the connection after reading
            with urlopen(request, timeout=timeout) as response:
                response.read()
        except OSError as error:
            # URLError, HTTPError and timeouts all derive from OSError.
            # The url is not included in the warning as it contains the token.
            warnings.warn(f"Could not send Telegram message to chat {chat_id}: {error}")
60 changes: 60 additions & 0 deletions tests/test_launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
from pathlib import Path

import torch
from lightning_trainable import Trainable, TrainableHParams
from lightning_trainable.launcher.fit import main
from torch.utils.data import TensorDataset

from lightning_trainable.launcher.grid import GridLauncher, status_count_counter


class BasicTrainableHParams(TrainableHParams):
    # Arguments unpacked into `torch.linspace` before the number of steps,
    # i.e. the start and end of the interval the training inputs are drawn from
    domain: list


class BasicTrainable(Trainable):
    """
    Minimal trainable used by the launcher tests: a small MLP
    regressing sin(x) on 1000 equidistant points in `hparams.domain`.
    """

    def __init__(self, hparams: BasicTrainableHParams | dict):
        if not isinstance(hparams, BasicTrainableHParams):
            hparams = BasicTrainableHParams(**hparams)

        # Column vector of inputs, shape (1000, 1)
        x = torch.linspace(*hparams.domain, 1000)[:, None]
        # Same shape as x. An additional axis here would make the
        # difference in `compute_metrics` broadcast to (batch, batch, 1),
        # i.e. compare every prediction with every target.
        y = torch.sin(x)
        train_data = TensorDataset(x, y)

        super().__init__(hparams, train_data=train_data)

        self.model = torch.nn.Sequential(
            torch.nn.Linear(1, 16), torch.nn.ReLU(),
            torch.nn.Linear(16, 1)
        )

    def compute_metrics(self, batch, batch_idx) -> dict:
        """
        Computes the mean squared error between prediction and target.

        :return: Dict with the single entry "loss".
        """
        x, y = batch
        return {
            "loss": ((self.model(x) - y) ** 2).mean()
        }


def test_fit_launcher():
    """
    Smoke test: fit `BasicTrainable` for one epoch through the
    command line entry point of the fit launcher.
    """
    config_file = Path(__file__).parent / "test_launcher_config.yaml"
    cli = [
        "model=tests.test_launcher.BasicTrainable",
        str(config_file),
        "max_epochs=1",
        "domain=[-5, 5]",
        "accelerator='cpu'",
    ]
    # The experiment name is a template filled from the hparams
    cli += ["--name", "{model_name};{max_epochs}"]
    main(cli)


def test_grid_launcher():
    """
    Runs a grid with two configurations (one per domain) through
    the grid launcher and expects both runs to exit with code 0.
    """
    config_file = Path(__file__).parent / "test_launcher_config.yaml"
    grid_spec = [
        ("model", ["tests.test_launcher.BasicTrainable"]),
        # Entry without key: passed to the fit launcher as is
        [config_file],
        ("max_epochs", [1]),
        ("domain", [[-5, 5], [-3, 3]]),
        ("accelerator", ['cpu']),
    ]

    launcher = GridLauncher()
    config_list = launcher.grid_spec_to_list(grid_spec)
    name_args = ["--name", "{model_name};{max_epochs}"]
    results = launcher.run_configs_and_wait(config_list, cli_args=name_args)
    assert status_count_counter(results) == {0: 2}
1 change: 1 addition & 0 deletions tests/test_launcher_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
batch_size: 128

0 comments on commit a3c6ed7

Please sign in to comment.