Implement HPO for PyTorch pipeline. (jpata#246)
* wip: implement HPO in pytorch pipeline

* fix: bugs after rebase

* chore: code formatting

* fix: minor bug

* fix: typo

* fix: lr cast to str when read from config

* try reducing --ntrain --ntest in tests

* update distbarrier and fix stale epochs (jpata#249)

* change pytorch CI/CD test to use gravnet model

* feat: implemented HPO using Ray Tune

Now able to perform hyperparameter search using random search, with
automatic trial launching and Ray-compatible checkpointing; a minimal
launch sketch follows the commit metadata below.

Support is still missing for:
- Trial schedulers
- Advanced Ray Tune search algorithms

* fix: flake8 error

* chore: update default config values for pyg

---------

Co-authored-by: Farouk Mokhtar <[email protected]>
erwulff and farakiko authored Oct 25, 2023
1 parent fc2083b commit 370f47f
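
For context, the following is a minimal sketch, not part of this commit, of how such a random-search run over a couple of MLPF hyperparameters could be launched with Ray Tune. It assumes a recent Ray 2.x API; train_loop_per_worker, the search space, and all parameter values are illustrative placeholders rather than the pipeline's actual entry point.

from ray import tune
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    # Hypothetical wrapper: build dataloaders, model and optimizer from
    # `config`, then call train_mlpf(..., hpo=True) so metrics and
    # checkpoints are reported back to Ray each epoch.
    pass

# Random search draws each trial's config independently from this space.
search_space = {
    "lr": tune.loguniform(1e-5, 1e-2),
    "embedding_dim": tune.choice([128, 256, 512]),
}

trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)

tuner = tune.Tuner(
    trainer,
    param_space={"train_loop_config": search_space},
    tune_config=tune.TuneConfig(metric="val_loss", mode="min", num_samples=16),
    run_config=RunConfig(name="mlpf_hpo"),
)
results = tuner.fit()
print(results.get_best_result().config)

With no search algorithm specified, Tune's default searcher simply samples the space at random, which matches the commit's stated scope (no trial schedulers or advanced search algorithms yet).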
Showing 14 changed files with 542 additions and 87 deletions.
2 changes: 1 addition & 1 deletion mlpf/pyg/mlpf.py
@@ -98,7 +98,7 @@ def __init__(
for i in range(num_convs):
self.conv_id.append(SelfAttentionLayer(embedding_dim))
self.conv_reg.append(SelfAttentionLayer(embedding_dim))
elif self.conv_type == "gnn-lsh":
elif self.conv_type == "gnn_lsh":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
for i in range(num_convs):
58 changes: 48 additions & 10 deletions mlpf/pyg/training.py
@@ -1,6 +1,8 @@
import pickle as pkl
from tempfile import TemporaryDirectory
import time
from typing import Optional
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
@@ -137,7 +139,7 @@ def train(
best_val_loss,
stale_epochs,
patience,
outpath,
outdir,
tensorboard_writer=None,
):
"""
@@ -238,10 +240,10 @@

torch.save(
{"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()},
f"{outpath}/best_weights.pth",
f"{outdir}/best_weights.pth",
)
_logger.info(
f"finished {itrain+1}/{len(train_loader)} iterations and saved the model at {outpath}/best_weights.pth" # noqa
f"finished {itrain+1}/{len(train_loader)} iterations and saved the model at {outdir}/best_weights.pth" # noqa
)
stale_epochs = torch.tensor(0, device=rank)
else:
@@ -278,7 +280,7 @@ def train(
return epoch_loss, valid_loss, best_val_loss, stale_epochs


def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, num_epochs, patience, outpath):
def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, num_epochs, patience, outdir, hpo=False):
"""
Will run a full training by calling train().
@@ -288,11 +290,11 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
train_loader: a pytorch geometric Dataloader that loads the training data in the form ~ DataBatch(X, ygen, ycands)
valid_loader: a pytorch geometric Dataloader that loads the validation data in the form ~ DataBatch(X, ygen, ycands)
patience: number of stale epochs before stopping the training
outpath: path to store the model weights and training plots
outdir: path to store the model weights and training plots
"""

if (rank == 0) or (rank == "cpu"):
tensorboard_writer = SummaryWriter(f"{outpath}/runs/")
tensorboard_writer = SummaryWriter(f"{outdir}/runs/")
else:
tensorboard_writer = False

@@ -306,7 +308,23 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
losses["train"][loss], losses["valid"][loss] = [], []

stale_epochs, best_val_loss = torch.tensor(0, device=rank), 99999.9
for epoch in range(num_epochs):
start_epoch = 0

if hpo:
import ray.train as ray_train
from ray.train import Checkpoint

checkpoint = ray_train.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
checkpoint_dir = Path(checkpoint_dir)
# TODO: EW, check if map_location should be "cpu" below
model.load_state_dict(torch.load(checkpoint_dir / "model.pt"))
optimizer.load_state_dict(torch.load(checkpoint_dir / "optim.pt"))
start_epoch = torch.load(checkpoint_dir / "extra_state.pt")["epoch"] + 1

for epoch in range(start_epoch, num_epochs):
_logger.info(f"Initiating epoch # {epoch}", color="bold")
t0 = time.time()

@@ -321,10 +339,30 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
best_val_loss,
stale_epochs,
patience,
outpath,
outdir,
tensorboard_writer,
)

if hpo:
# save model, optimizer and epoch number for HPO-supported checkpointing
if (rank == 0) or (rank == "cpu"):
# Ray automatically syncs the checkpoint to persistent storage
with TemporaryDirectory() as temp_checkpoint_dir:
temp_checkpoint_dir = Path(temp_checkpoint_dir)
torch.save(model.state_dict(), temp_checkpoint_dir / "model.pt")
torch.save(optimizer.state_dict(), temp_checkpoint_dir / "optim.pt")
torch.save({"epoch": epoch}, temp_checkpoint_dir / "extra_state.pt")

# report metrics and checkpoint to Ray
ray_train.report(
dict(
loss=losses_t["Total"],
val_loss=losses_v["Total"],
epoch=epoch,
),
checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
)

if stale_epochs > patience:
break

@@ -378,10 +416,10 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
ax.set_ylim(0.8 * losses["train"][loss][-1], 1.2 * losses["train"][loss][-1])
ax.legend(title="MLPF", loc="best", title_fontsize=20, fontsize=15)
plt.tight_layout()
plt.savefig(f"{outpath}/mlpf_loss_{loss}.pdf")
plt.savefig(f"{outdir}/mlpf_loss_{loss}.pdf")
plt.close()

with open(f"{outpath}/mlpf_losses.pkl", "wb") as f:
with open(f"{outdir}/mlpf_losses.pkl", "wb") as f:
pkl.dump(losses, f)

_logger.info(f"Done with training. Total training time on device {rank} is {round((time.time() - t0_initial)/60,3)}min")