Skip to content

Commit

Permalink
Merge pull request jpata#12 from erwulff/dev_feb24_flatiron
Browse files Browse the repository at this point in the history
Dev feb24 flatiron
  • Loading branch information
erwulff authored Feb 8, 2024
2 parents 5c8d584 + b8e7f3a commit 08f0572
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 19 deletions.
1 change: 0 additions & 1 deletion mlpf/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,6 @@ def raytune(
str(Path(cfg["raytune"]["local_dir"]) / name / "config.yaml"),
) # Copy the config file to the train dir for later reference

ray.tune.ray_trial_executor.DEFAULT_GET_TIMEOUT = 1 * 60 * 60 # Avoid timeout errors
if not local:
ray.init(address="auto")

Expand Down
37 changes: 32 additions & 5 deletions mlpf/pyg/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
X_FEATURES,
save_HPs,
get_lr_schedule,
count_parameters,
)


Expand Down Expand Up @@ -501,9 +502,11 @@ def run(rank, world_size, config, args, outdir, logfile):
if Path(config["load"]).name == "checkpoint.pth":
# the checkpoint is likely from a Ray Train run and we need to step one dir higher up
loaddir = str(Path(config["load"]).parent.parent.parent)
testdir_name = "_" + Path(config["load"]).parent.stem
else:
# the checkpoint is likely from a DDP run and we need to step up one dir less
loaddir = str(Path(config["load"]).parent.parent)
testdir_name = "_" + Path(config["load"]).stem

with open(f"{loaddir}/model_kwargs.pkl", "rb") as f:
model_kwargs = pkl.load(f)
Expand Down Expand Up @@ -552,8 +555,14 @@ def run(rank, world_size, config, args, outdir, logfile):
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

trainable_params, nontrainable_params, table = count_parameters(model)

if (rank == 0) or (rank == "cpu"):
_logger.info(model)
_logger.info(f"Trainable parameters: {trainable_params}")
_logger.info(f"Non-trainable parameters: {nontrainable_params}")
_logger.info(f"Total parameters: {trainable_params + nontrainable_params}")
_logger.info(table.to_string(index=False))

if args.train:
if (rank == 0) or (rank == "cpu"):
Expand All @@ -571,6 +580,9 @@ def run(rank, world_size, config, args, outdir, logfile):
comet_experiment.log_parameter("rank", rank)
comet_experiment.log_parameters(config, prefix="config:")
comet_experiment.set_model_graph(model)
comet_experiment.log_parameter(trainable_params, "trainable_params")
comet_experiment.log_parameter(nontrainable_params, "nontrainable_params")
comet_experiment.log_parameter(trainable_params + nontrainable_params, "total_trainable_params")
comet_experiment.log_code("mlpf/pyg/training.py")
comet_experiment.log_code("mlpf/pyg_pipeline.py")
# save overridden config then log to comet
Expand Down Expand Up @@ -620,7 +632,12 @@ def run(rank, world_size, config, args, outdir, logfile):
assert args.train, "Please train a model before testing, or load a model with --load"
assert outdir is not None, "Error: no outdir to evaluate model from"
else:
outdir = str(Path(config["load"]).parent.parent)
if Path(config["load"]).name == "checkpoint.pth":
# the checkpoint is likely from a Ray Train run and we need to step one dir higher up
outdir = str(Path(config["load"]).parent.parent.parent)
else:
# the checkpoint is likely from a DDP run and we need to step up one dir less
outdir = str(Path(config["load"]).parent.parent)

for type_ in config["test_dataset"][config["dataset"]]: # will be "physical", "gun"
batch_size = config["test_dataset"][config["dataset"]][type_]["batch_size"] * config["gpu_batch_multiplier"]
Expand Down Expand Up @@ -766,9 +783,6 @@ def train_ray_trial(config, args, outdir=None):
world_rank = ray.train.get_context().get_world_rank()
world_size = ray.train.get_context().get_world_size()

# keep writing the logs
_configLogger("mlpf", filename=f"{outdir}/train.log")

model_kwargs = {
"input_dim": len(X_FEATURES[config["dataset"]]),
"num_classes": len(CLASS_LABELS[config["dataset"]]),
Expand All @@ -781,8 +795,14 @@ def train_ray_trial(config, args, outdir=None):
model = ray.train.torch.prepare_model(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

trainable_params, nontrainable_params, table = count_parameters(model)

if (rank == 0) or (rank == "cpu"):
_logger.info(model)
_logger.info(f"Trainable parameters: {trainable_params}")
_logger.info(f"Non-trainable parameters: {nontrainable_params}")
_logger.info(f"Total parameters: {trainable_params + nontrainable_params}")
_logger.info(table)

if (rank == 0) or (rank == "cpu"):
save_HPs(args, model, model_kwargs, outdir) # save model_kwargs and hyperparameters
Expand All @@ -802,6 +822,9 @@ def train_ray_trial(config, args, outdir=None):
comet_experiment.log_parameter("world_rank", world_rank)
comet_experiment.log_parameters(config, prefix="config:")
comet_experiment.set_model_graph(model)
comet_experiment.log_parameter(trainable_params, "trainable_params")
comet_experiment.log_parameter(nontrainable_params, "nontrainable_params")
comet_experiment.log_parameter(trainable_params + nontrainable_params, "total_trainable_params")
comet_experiment.log_code(str(Path(outdir).parent.parent / "mlpf/pyg/training.py"))
comet_experiment.log_code(str(Path(outdir).parent.parent / "mlpf/pyg_pipeline.py"))
comet_experiment.log_code(str(Path(outdir).parent.parent / "mlpf/raytune/pt_search_space.py"))
Expand Down Expand Up @@ -955,7 +978,11 @@ def run_hpo(config, args):
yaml.dump(config, file)

if not args.local:
ray.init(address="auto")
ray.init(
address=os.environ["ip_head"],
_node_ip_address=os.environ["head_node_ip"],
_temp_dir="/mnt/ceph/users/ewulff/tmp_ray",
)

sched = get_raytune_schedule(config["raytune"])
search_alg = get_raytune_search_alg(config["raytune"])
Expand Down
30 changes: 30 additions & 0 deletions mlpf/pyg/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import pickle as pkl

import pandas as pd
import torch
import torch.utils.data
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR, ConstantLR
Expand Down Expand Up @@ -277,3 +278,32 @@ def get_lr_schedule(config, opt, epochs=None, steps_per_epoch=None, last_epoch=-
else:
raise ValueError("Supported values for lr_schedule are 'constant', 'onecycle' and 'cosinedecay'.")
return lr_schedule


def count_parameters(model):
table = pd.DataFrame(columns=["Modules", "Trainable params", "Non-tranable params"])
trainable_params = 0
nontrainable_params = 0
for ii, (name, parameter) in enumerate(model.named_parameters()):
params = parameter.numel()
if not parameter.requires_grad:
table = pd.concat(
[
table,
pd.DataFrame(
{"Modules": name, "Trainable Parameters": "-", "Non-tranable Parameters": params}, index=[ii]
),
]
)
nontrainable_params += params
else:
table = pd.concat(
[
table,
pd.DataFrame(
{"Modules": name, "Trainable Parameters": params, "Non-tranable Parameters": "-"}, index=[ii]
),
]
)
trainable_params += params
return trainable_params, nontrainable_params, table
13 changes: 0 additions & 13 deletions mlpf/raytune/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ray.tune.search.bayesopt import BayesOptSearch
from ray.tune.search.bohb import TuneBOHB
from ray.tune.search.hyperopt import HyperOptSearch
from ray.tune.search.nevergrad import NevergradSearch
from ray.tune.search.skopt import SkOptSearch

# from ray.tune.search.hebo import HEBOSearch # HEBO is not yet supported
Expand Down Expand Up @@ -60,18 +59,6 @@ def get_raytune_search_alg(raytune_cfg, seeds=False):
mode=raytune_cfg["default_mode"],
convert_to_python=True,
)
if raytune_cfg["search_alg"] == "nevergrad":
print("INFO: Using bayesian optimization from nevergrad")
import nevergrad as ng

return NevergradSearch(
optimizer=ng.optimizers.BayesOptim(
pca=False,
init_budget=raytune_cfg["nevergrad"]["n_random_steps"],
),
metric=raytune_cfg["default_metric"],
mode=raytune_cfg["default_mode"],
)
# HEBO is not yet supported
# if (raytune_cfg["search_alg"] == "hebo") or (raytune_cfg["search_alg"] == "HEBO"):
# print("Using HEBOSearch")
Expand Down

0 comments on commit 08f0572

Please sign in to comment.