Commit

Reduce the number of data loader workers per dataset in pytorch (jpata#262)

* update config

* update parameters
farakiko authored Oct 30, 2023
1 parent b856463 commit 18001a3
Showing 9 changed files with 272 additions and 313 deletions.
58 changes: 12 additions & 46 deletions mlpf/pyg/PFDataset.py
@@ -1,25 +1,27 @@
from typing import List, Optional
from types import SimpleNamespace
from typing import List, Optional

import tensorflow_datasets as tfds
import torch
import torch.utils.data
from torch import Tensor
import torch_geometric
from torch import Tensor
from torch_geometric.data import Batch, Data


class PFDataset:
"""Builds a DataSource from tensorflow datasets."""

def __init__(self, data_dir, name, split, keys_to_get, pad_3d=True, num_samples=None):
def __init__(self, data_dir, name, split, num_samples=None):
"""
Args
data_dir: path to tensorflow_datasets (e.g. `../data/tensorflow_datasets/`)
name: sample and version (e.g. `clic_edm_ttbar_pf:1.5.0`)
split: "train" or "test
split: "train" or "test" (if "valid" then will use "test")
keys_to_get: any selection of ["X", "ygen", "ycand"] to retrieve
"""
if split == "valid":
split = "test"

builder = tfds.builder(name, data_dir=data_dir)
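For orientation, a minimal usage sketch of the new constructor signature above. The path and dataset name are the illustrative values from the docstring, the import assumes the repository root is on PYTHONPATH, and the tensorflow_datasets dataset must already be built locally:

from mlpf.pyg.PFDataset import PFDataset

# "valid" is mapped internally to the "test" split, as the docstring notes.
dataset = PFDataset(
    data_dir="../data/tensorflow_datasets/",
    name="clic_edm_ttbar_pf:1.5.0",
    split="train",
    num_samples=1000,  # optional cap on the number of events (uses torch.utils.data.Subset)
)
print(len(dataset), repr(dataset))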

@@ -36,50 +38,22 @@ def __init__(self, data_dir, name, split, keys_to_get, pad_3d=True, num_samples=
self.ds.dataset_info.features = tmp.features
self.rep = self.ds.__repr__()

# any selection of ["X", "ygen", "ycand"] to retrieve
self.keys_to_get = keys_to_get

self.pad_3d = pad_3d

if num_samples:
self.ds = torch.utils.data.Subset(self.ds, range(num_samples))

def get_sampler(self):
sampler = torch.utils.data.RandomSampler(self.ds)
return sampler

def get_distributed_sampler(self):
sampler = torch.utils.data.distributed.DistributedSampler(self.ds)
return sampler

def get_loader(self, batch_size, world_size, rank, use_cuda=False, num_workers=0, prefetch_factor=None):
if (num_workers > 0) and (prefetch_factor is None):
prefetch_factor = 2 # default prefetch_factor when num_workers>0

if world_size > 1:
sampler = self.get_distributed_sampler()
else:
sampler = self.get_sampler()

return DataLoader(
self.ds,
batch_size=batch_size,
collate_fn=Collater(self.keys_to_get, pad_3d=self.pad_3d),
sampler=sampler,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=use_cuda,
pin_memory_device="cuda:{}".format(rank) if use_cuda else "",
)
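The get_loader method deleted above encapsulated the per-dataset loader construction; presumably that construction now lives elsewhere among the nine changed files. A hedged sketch of the same construction done standalone with torch.utils.data (the name make_loader and its arguments are placeholders, and collate_fn stands in for the repository's Collater):

import torch
import torch.utils.data


def make_loader(dataset, collate_fn, batch_size, world_size, rank, use_cuda=False, num_workers=1, prefetch_factor=None):
    # Use a DistributedSampler when training on several GPUs, otherwise shuffle.
    if world_size > 1:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)

    # prefetch_factor is only meaningful when worker processes are enabled.
    if num_workers > 0 and prefetch_factor is None:
        prefetch_factor = 2  # PyTorch's default when num_workers > 0

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        sampler=sampler,
        num_workers=num_workers,  # keep this small: workers are spawned per dataset
        prefetch_factor=prefetch_factor,
        pin_memory=use_cuda,
        pin_memory_device="cuda:{}".format(rank) if use_cuda else "",
    )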

def __len__(self):
return len(self.ds)

def __repr__(self):
return self.rep


class DataLoader(torch.utils.data.DataLoader):
def my_getitem(self, vals):
records = self.data_source.__getitems__(vals)
return [self.dataset_info.features.deserialize_example_np(record, decoders=self.decoders) for record in records]


class PFDataLoader(torch.utils.data.DataLoader):
"""
Copied from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/loader/dataloader.html#DataLoader
because we need to implement our own Collater class to load the tensorflow_datasets (see below).
@@ -148,14 +122,6 @@ def __call__(self, inputs):
return ret
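The PFDataLoader docstring above motivates a custom collate step. Purely as an illustration (this is not the repository's Collater, whose signature in the removed get_loader took keys_to_get and pad_3d), a minimal collater that turns a list of deserialized tensorflow_datasets records into a torch_geometric Batch could look like this:

import numpy as np
import torch
from torch_geometric.data import Batch, Data


class MinimalCollater:
    # Illustrative only: assumes each record is a dict of numpy arrays
    # containing the keys "X", "ygen" and "ycand".
    def __init__(self, keys_to_get=("X", "ygen", "ycand")):
        self.keys_to_get = keys_to_get

    def __call__(self, inputs):
        datas = [
            Data(**{k: torch.as_tensor(np.asarray(rec[k])) for k in self.keys_to_get})
            for rec in inputs
        ]
        return Batch.from_data_list(datas)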


def my_getitem(self, vals):
# print(
# "reading dataset {}:{} from disk in slice {}, total={}".format(self.dataset_info.name, self.split, vals, len(self))
# )
records = self.data_source.__getitems__(vals)
return [self.dataset_info.features.deserialize_example_np(record, decoders=self.decoders) for record in records]


class InterleavedIterator(object):
"""Will combine DataLoaders of different lengths and batch sizes."""

32 changes: 20 additions & 12 deletions mlpf/pyg/training.py
@@ -1,8 +1,8 @@
import pickle as pkl
from tempfile import TemporaryDirectory
import time
from typing import Optional
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
@@ -11,14 +11,12 @@
import tqdm
from torch import Tensor, nn
from torch.nn import functional as F
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils.tensorboard import SummaryWriter

from .logger import _logger
from .utils import unpack_predictions, unpack_target

from torch.profiler import profile, record_function, ProfilerActivity


# Ignore divide by 0 errors
np.seterr(divide="ignore", invalid="ignore")

@@ -150,7 +148,6 @@ def train_and_valid(rank, world_size, model, optimizer, data_loader, is_train):
for itrain, batch in tqdm.tqdm(
enumerate(data_loader), total=len(data_loader), desc=f"{train_or_valid} loop on rank={rank}"
):

if world_size > 1:
_logger.info(f"Step {itrain} on rank={rank}")

@@ -256,13 +253,21 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
losses_v = train_and_valid(rank, world_size, model, optimizer, valid_loader, False)

if (rank == 0) or (rank == "cpu"):
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model_state_dict = model.module.state_dict()
else:
model_state_dict = model.state_dict()

torch.save(
{"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()},
# "{outdir}/weights-{epoch:02d}-{val_loss:.6f}.pth".format(
# outdir=outdir, epoch=epoch+1, val_loss=losses_v["Total"]),
f"{outdir}/weights_epoch{epoch}.pth",
)

if losses_v["Total"] < best_val_loss:
best_val_loss = losses_v["Total"]
stale_epochs = 0
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model_state_dict = model.module.state_dict()
else:
model_state_dict = model.state_dict()

torch.save(
{"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()},
@@ -347,7 +352,10 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
with open(f"{outdir}/mlpf_losses.pkl", "wb") as f:
pkl.dump(losses, f)

if tensorboard_writer:
tensorboard_writer.flush()
if tensorboard_writer:
tensorboard_writer.flush()

if world_size > 1:
dist.barrier()

_logger.info(f"Done with training. Total training time on device {rank} is {round((time.time() - t0_initial)/60,3)}min")
