From af3225cc7bb57e07a20ed10fb489f9995d4c3aca Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Tue, 13 Feb 2024 13:42:56 +0100 Subject: [PATCH] enable FlashAttention in pytorch, update to torch 2.2.0 (#292) * implement attention configuration --- .github/workflows/test.yml | 12 +-- mlpf/pyg/PFDataset.py | 23 +++-- mlpf/pyg/inference.py | 2 +- mlpf/pyg/mlpf.py | 28 ++++++- mlpf/pyg/training.py | 116 +++++++++++++++++++------- mlpf/pyg_pipeline.py | 14 ++++ parameters/pytorch/pyg-clic-hits.yaml | 2 + parameters/pytorch/pyg-clic.yaml | 2 + parameters/pytorch/pyg-cms.yaml | 10 ++- parameters/pytorch/pyg-delphes.yaml | 2 + scripts/local_test_pyg.sh | 6 +- scripts/tallinn/a100/pytorch.sh | 15 +++- scripts/tallinn/rtx/pytorch.sh | 2 +- 13 files changed, 176 insertions(+), 58 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 056f30201..05c7978ed 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,8 +28,8 @@ jobs: python-version: "3.10.12" cache: "pip" - run: pip install -r requirements.txt - - run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - - run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html + - run: pip3 install torch==2.2.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + - run: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv torch_geometric -f https://data.pyg.org/whl/torch-2.2.0+cpu.html tf-unittests: runs-on: ubuntu-22.04 @@ -101,8 +101,8 @@ jobs: python-version: "3.10.12" cache: "pip" - run: pip install -r requirements.txt - - run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - - run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html + - run: pip3 install torch==2.2.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + - run: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv torch_geometric -f https://data.pyg.org/whl/torch-2.2.0+cpu.html - run: PYTHONPATH=. 
python3 -m unittest tests/test_torch_and_tf.py pyg-pipeline: @@ -115,6 +115,6 @@ jobs: python-version: "3.10.12" cache: "pip" - run: pip install -r requirements.txt - - run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - - run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html + - run: pip3 install torch==2.2.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + - run: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv torch_geometric -f https://data.pyg.org/whl/torch-2.2.0+cpu.html - run: ./scripts/local_test_pyg.sh diff --git a/mlpf/pyg/PFDataset.py b/mlpf/pyg/PFDataset.py index f6651fd0f..bbea414bd 100644 --- a/mlpf/pyg/PFDataset.py +++ b/mlpf/pyg/PFDataset.py @@ -103,15 +103,19 @@ def __init__( ) +def next_power_of_2(x): + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + class Collater: """Based on the Collater found on torch_geometric docs we build our own.""" - def __init__(self, keys_to_get, follow_batch=None, exclude_keys=None, pad_bin_size=640, pad_3d=True): + def __init__(self, keys_to_get, follow_batch=None, exclude_keys=None, pad_3d=True, pad_power_of_two=True): self.follow_batch = follow_batch self.exclude_keys = exclude_keys self.keys_to_get = keys_to_get - self.pad_bin_size = pad_bin_size self.pad_3d = pad_3d + self.pad_power_of_two = False def __call__(self, inputs): num_samples_in_batch = len(inputs) @@ -129,7 +133,16 @@ def __call__(self, inputs): if not self.pad_3d: return ret else: - ret = {k: torch_geometric.utils.to_dense_batch(getattr(ret, k), ret.batch) for k in elem_keys} + # pad to closest power of two + if self.pad_power_of_two: + sizes = [next_power_of_2(len(b.X)) for b in batch] + max_size = max(sizes) + else: + max_size = None + ret = { + k: torch_geometric.utils.to_dense_batch(getattr(ret, k), ret.batch, max_num_nodes=max_size) + for k in elem_keys + } ret["mask"] = ret["X"][1] @@ -185,7 +198,7 @@ def __len__(self): return len_ -def get_interleaved_dataloaders(world_size, rank, config, use_cuda, pad_3d, use_ray): +def get_interleaved_dataloaders(world_size, rank, config, use_cuda, pad_3d, pad_power_of_two, use_ray): loaders = {} for split in ["train", "valid"]: # build train, valid dataset and dataloaders loaders[split] = [] @@ -219,7 +232,7 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, pad_3d, use_ loader = PFDataLoader( dataset, batch_size=batch_size, - collate_fn=Collater(["X", "ygen"], pad_3d=pad_3d), + collate_fn=Collater(["X", "ygen"], pad_3d=pad_3d, pad_power_of_two=pad_power_of_two), sampler=sampler, num_workers=config["num_workers"], prefetch_factor=config["prefetch_factor"], diff --git a/mlpf/pyg/inference.py b/mlpf/pyg/inference.py index f395df2fc..6f2e6df86 100644 --- a/mlpf/pyg/inference.py +++ b/mlpf/pyg/inference.py @@ -7,7 +7,6 @@ import mplhep import numpy as np import torch -import torch_geometric import tqdm import vector from jet_utils import build_dummy_array, match_two_jet_collections @@ -22,6 +21,7 @@ plot_particles, plot_sum_energy, ) +import torch_geometric from torch_geometric.data import Batch from .logger import _logger diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py index c1409072a..513249a4a 100644 --- a/mlpf/pyg/mlpf.py +++ b/mlpf/pyg/mlpf.py @@ -4,6 +4,9 @@ from .gnn_lsh import CombinedGraphLayer +from torch.backends.cuda import sdp_kernel +from pyg.logger import _logger + class GravNetLayer(nn.Module): 
def __init__(self, embedding_dim, space_dimensions, propagate_dimensions, k, dropout): @@ -22,7 +25,7 @@ def forward(self, x, batch_index): class SelfAttentionLayer(nn.Module): - def __init__(self, embedding_dim=128, num_heads=2, width=128, dropout=0.1): + def __init__(self, embedding_dim=128, num_heads=2, width=128, dropout=0.1, attention_type="efficient"): super(SelfAttentionLayer, self).__init__() self.act = nn.ELU self.mha = torch.nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True) @@ -32,9 +35,20 @@ def __init__(self, embedding_dim=128, num_heads=2, width=128, dropout=0.1): nn.Linear(embedding_dim, width), self.act(), nn.Linear(width, embedding_dim), self.act() ) self.dropout = torch.nn.Dropout(dropout) + self.attention_type = attention_type + _logger.info("using attention_type={}".format(attention_type)) + self.attn_params = { + "math": {"enable_math": True, "enable_mem_efficient": False, "enable_flash": False}, + "efficient": {"enable_math": False, "enable_mem_efficient": True, "enable_flash": False}, + "flash": {"enable_math": False, "enable_mem_efficient": False, "enable_flash": True}, + } def forward(self, x, mask): - x = self.norm0(x + self.mha(x, x, x, key_padding_mask=mask, need_weights=False)[0]) + # explicitly call the desired attention mechanism + with sdp_kernel(**self.attn_params[self.attention_type]): + mha_out = self.mha(x, x, x, need_weights=False)[0] + + x = self.norm0(x + mha_out) x = self.norm1(x + self.seq(x)) x = self.dropout(x) x = x * (~mask.unsqueeze(-1)) @@ -117,6 +131,7 @@ def __init__( propagate_dimensions=32, space_dimensions=4, conv_type="gravnet", + attention_type="flash", # gnn-lsh specific parameters bin_size=640, max_num_bins=200, @@ -168,8 +183,12 @@ def __init__( self.conv_id = nn.ModuleList() self.conv_reg = nn.ModuleList() for i in range(num_convs): - self.conv_id.append(SelfAttentionLayer(embedding_dim, num_heads, width, dropout)) - self.conv_reg.append(SelfAttentionLayer(embedding_dim, num_heads, width, dropout)) + self.conv_id.append( + SelfAttentionLayer(embedding_dim, num_heads, width, dropout, attention_type=attention_type) + ) + self.conv_reg.append( + SelfAttentionLayer(embedding_dim, num_heads, width, dropout, attention_type=attention_type) + ) elif self.conv_type == "mamba": self.conv_id = nn.ModuleList() self.conv_reg = nn.ModuleList() @@ -209,6 +228,7 @@ def __init__( # elementwise DNN for node charge regression, classes (-1, 0, 1) self.nn_charge = ffn(decoding_dim + num_classes, 3, width, self.act, dropout) + # @torch.compile def forward(self, X_features, batch_or_mask): embeddings_id, embeddings_reg = [], [] if self.num_convs != 0: diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py index 30b8f6071..10f61a109 100644 --- a/mlpf/pyg/training.py +++ b/mlpf/pyg/training.py @@ -202,11 +202,12 @@ def train_and_valid( model, optimizer, data_loader, - is_train, + is_train=True, lr_schedule=None, comet_experiment=None, comet_step_freq=None, epoch=None, + dtype=torch.float32, ): """ Performs training over a given epoch. Will run a validation step every N_STEPS and after the last training batch. 
@@ -231,6 +232,8 @@ def train_and_valid( enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch} {train_or_valid} loop on rank={rank}" ) + device_type = "cuda" if isinstance(rank, int) else "cpu" + for itrain, batch in iterator: batch = batch.to(rank, non_blocking=True) @@ -242,20 +245,24 @@ def train_and_valid( conv_type = model.conv_type batchidx_or_mask = batch.batch if conv_type == "gravnet" else batch.mask - if is_train: - ypred = model(batch.X, batchidx_or_mask) - else: - with torch.no_grad(): + + with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"): + if is_train: ypred = model(batch.X, batchidx_or_mask) + else: + with torch.no_grad(): + ypred = model(batch.X, batchidx_or_mask) + ypred = unpack_predictions(ypred) - if is_train: - loss = mlpf_loss(ygen, ypred) - for param in model.parameters(): - param.grad = None - else: - with torch.no_grad(): + with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"): + if is_train: loss = mlpf_loss(ygen, ypred) + for param in model.parameters(): + param.grad = None + else: + with torch.no_grad(): + loss = mlpf_loss(ygen, ypred) if is_train: loss["Total"].backward() @@ -302,6 +309,7 @@ def train_mlpf( num_epochs, patience, outdir, + dtype, start_epoch=1, lr_schedule=None, use_ray=False, @@ -348,15 +356,37 @@ def train_mlpf( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, with_stack=True ) as prof: with record_function("model_train"): - losses_t = train_and_valid(rank, world_size, model, optimizer, train_loader, True, lr_schedule) + losses_t = train_and_valid( + rank, world_size, model, optimizer, train_loader, is_train=True, lr_schedule=lr_schedule, dtype=dtype + ) prof.export_chrome_trace("trace.json") else: losses_t = train_and_valid( - rank, world_size, model, optimizer, train_loader, True, lr_schedule, comet_experiment, comet_step_freq, epoch + rank, + world_size, + model, + optimizer, + train_loader, + is_train=True, + lr_schedule=lr_schedule, + comet_experiment=comet_experiment, + comet_step_freq=comet_step_freq, + epoch=epoch, + dtype=dtype, ) losses_v = train_and_valid( - rank, world_size, model, optimizer, valid_loader, False, None, comet_experiment, comet_step_freq, epoch + rank, + world_size, + model, + optimizer, + valid_loader, + is_train=False, + lr_schedule=None, + comet_experiment=comet_experiment, + comet_step_freq=comet_step_freq, + epoch=epoch, + dtype=dtype, ) if comet_experiment: @@ -486,8 +516,13 @@ def run(rank, world_size, config, args, outdir, logfile): """Demo function that will be passed to each gpu if (world_size > 1) else will run normally on the given device.""" pad_3d = config["conv_type"] != "gravnet" + pad_power_of_two = config["conv_type"] == "attention" and config["model"]["attention"]["attention_type"] == "flash" + use_cuda = rank != "cpu" + dtype = getattr(torch, config["dtype"]) + _logger.info("using dtype={}".format(dtype)) + if world_size > 1: os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" @@ -511,6 +546,10 @@ def run(rank, world_size, config, args, outdir, logfile): with open(f"{loaddir}/model_kwargs.pkl", "rb") as f: model_kwargs = pkl.load(f) _logger.info("model_kwargs: {}".format(model_kwargs)) + + if config["conv_type"] == "attention": + model_kwargs["attention_type"] = config["model"]["attention"]["attention_type"] + model = MLPF(**model_kwargs).to(torch.device(rank)) optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"]) @@ -544,6 +583,7 @@ def 
run(rank, world_size, config, args, outdir, logfile): "sin_phi_mode": config["model"]["sin_phi_mode"], "cos_phi_mode": config["model"]["cos_phi_mode"], "energy_mode": config["model"]["energy_mode"], + "attention_type": config["model"]["attention"]["attention_type"], **config["model"][config["conv_type"]], } model = MLPF(**model_kwargs) @@ -599,6 +639,7 @@ def run(rank, world_size, config, args, outdir, logfile): config, use_cuda, pad_3d, + pad_power_of_two, use_ray=False, ) steps_per_epoch = len(loaders["train"]) @@ -615,6 +656,7 @@ def run(rank, world_size, config, args, outdir, logfile): config["num_epochs"], config["patience"], outdir, + dtype, start_epoch=start_epoch, lr_schedule=lr_schedule, use_ray=False, @@ -677,18 +719,20 @@ def run(rank, world_size, config, args, outdir, logfile): else: jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.4) - run_predictions( - world_size, - rank, - model, - test_loader, - sample, - outdir, - jetdef, - jet_ptcut=15.0, - jet_match_dr=0.1, - dir_name=testdir_name, - ) + device_type = "cuda" if isinstance(rank, int) else "cpu" + with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"): + run_predictions( + world_size, + rank, + model, + test_loader, + sample, + outdir, + jetdef, + jet_ptcut=15.0, + jet_match_dr=0.1, + dir_name=testdir_name, + ) if (rank == 0) or (rank == "cpu"): # make plots and export to onnx only on a single machine if args.make_plots: @@ -700,13 +744,15 @@ def run(rank, world_size, config, args, outdir, logfile): if args.export_onnx: try: - dummy_features = torch.randn(1, 640, model_kwargs["input_dim"], device=rank) - dummy_mask = torch.zeros(1, 640, dtype=torch.bool, device=rank) + dummy_features = torch.randn(1, 8192, model_kwargs["input_dim"], device=rank) + dummy_mask = torch.zeros(1, 8192, dtype=torch.bool, device=rank) + + # Torch ONNX export in the old way torch.onnx.export( model, (dummy_features, dummy_mask), "test.onnx", - verbose=True, + verbose=False, input_names=["features", "mask"], output_names=["id", "momentum", "charge"], dynamic_axes={ @@ -717,6 +763,10 @@ def run(rank, world_size, config, args, outdir, logfile): "charge": [0, 1], }, ) + + # Torch ONNX export in the new way + # onnx_program = torch.onnx.dynamo_export(model, (dummy_features, dummy_mask)) + # onnx_program.save("test.onnx") except Exception as e: print("ONNX export failed: {}".format(e)) @@ -730,6 +780,10 @@ def override_config(config, args): arg_value = getattr(args, arg) if arg_value is not None: config[arg] = arg_value + + if not (args.attention_type is None): + config["model"]["attention"]["attention_type"] = args.attention_type + return config @@ -777,6 +831,7 @@ def train_ray_trial(config, args, outdir=None): outdir = ray.train.get_context().get_trial_dir() pad_3d = config["conv_type"] != "gravnet" + pad_power_of_two = config["conv_type"] == "attention" and config["model"]["attention"]["attention_type"] == "flash" use_cuda = True rank = ray.train.get_context().get_local_rank() @@ -789,6 +844,7 @@ def train_ray_trial(config, args, outdir=None): **config["model"][config["conv_type"]], } model = MLPF(**model_kwargs) + if world_size > 1: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # optimizer should be created after distributing the model to devices with ray.train.torch.prepare_model(model) @@ -809,7 +865,7 @@ def train_ray_trial(config, args, outdir=None): _logger.info("Creating experiment dir {}".format(outdir)) _logger.info(f"Model directory {outdir}", color="bold") - loaders = 
get_interleaved_dataloaders(world_size, rank, config, use_cuda, pad_3d, use_ray=True) + loaders = get_interleaved_dataloaders(world_size, rank, config, use_cuda, pad_3d, pad_power_of_two, use_ray=True) if args.comet: comet_experiment = create_comet_experiment( diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index 6feddcc6d..a0fa2b0ad 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -71,6 +71,20 @@ parser.add_argument("--comet-step-freq", type=int, default=None, help="step frequency for saving comet metrics") parser.add_argument("--experiments-dir", type=str, default=None, help="base directory within which trainings are stored") parser.add_argument("--pipeline", action="store_true", default=None, help="test is running in pipeline") +parser.add_argument( + "--dtype", + type=str, + default=None, + help="data type for training", + choices=["float32", "float16", "bfloat16"], +) +parser.add_argument( + "--attention-type", + type=str, + default=None, + help="attention type for self-attention layer", + choices=["math", "efficient", "flash"], +) def main(): diff --git a/parameters/pytorch/pyg-clic-hits.yaml b/parameters/pytorch/pyg-clic-hits.yaml index b1542fa4b..38279764b 100644 --- a/parameters/pytorch/pyg-clic-hits.yaml +++ b/parameters/pytorch/pyg-clic-hits.yaml @@ -19,6 +19,7 @@ checkpoint_freq: comet_name: particleflow-pt comet_offline: False comet_step_freq: 10 +dtype: float32 model: pt_mode: linear @@ -63,6 +64,7 @@ model: activation: "elu" # attention specific paramters num_heads: 2 + attention_type: flash mamba: conv_type: mamba diff --git a/parameters/pytorch/pyg-clic.yaml b/parameters/pytorch/pyg-clic.yaml index 54a64cde5..8452a216d 100644 --- a/parameters/pytorch/pyg-clic.yaml +++ b/parameters/pytorch/pyg-clic.yaml @@ -20,6 +20,7 @@ checkpoint_freq: comet_name: particleflow-pt comet_offline: False comet_step_freq: 10 +dtype: float32 model: pt_mode: linear @@ -64,6 +65,7 @@ model: activation: "elu" # attention specific paramters num_heads: 2 + attention_type: flash mamba: conv_type: mamba diff --git a/parameters/pytorch/pyg-cms.yaml b/parameters/pytorch/pyg-cms.yaml index f8890001c..7ee03154d 100644 --- a/parameters/pytorch/pyg-cms.yaml +++ b/parameters/pytorch/pyg-cms.yaml @@ -6,9 +6,9 @@ data_dir: gpus: 1 gpu_batch_multiplier: 1 load: -num_epochs: 50 +num_epochs: 10 patience: 20 -lr: 0.001 +lr: 0.0005 lr_schedule: cosinedecay # constant, cosinedecay, onecycle conv_type: gnn_lsh ntrain: @@ -20,6 +20,7 @@ checkpoint_freq: comet_name: particleflow-pt comet_offline: False comet_step_freq: 10 +dtype: bfloat16 model: pt_mode: linear @@ -60,10 +61,11 @@ model: embedding_dim: 256 width: 256 num_convs: 3 - dropout: 0.0 + dropout: 0.3 activation: "elu" # attention specific paramters - num_heads: 2 + num_heads: 16 + attention_type: flash mamba: conv_type: mamba diff --git a/parameters/pytorch/pyg-delphes.yaml b/parameters/pytorch/pyg-delphes.yaml index 84797fc47..6746bcda6 100644 --- a/parameters/pytorch/pyg-delphes.yaml +++ b/parameters/pytorch/pyg-delphes.yaml @@ -20,6 +20,7 @@ checkpoint_freq: comet_name: particleflow-pt comet_offline: False comet_step_freq: 10 +dtype: float32 model: pt_mode: linear @@ -64,6 +65,7 @@ model: activation: "elu" # attention specific paramters num_heads: 2 + attention_type: flash mamba: conv_type: mamba diff --git a/scripts/local_test_pyg.sh b/scripts/local_test_pyg.sh index b9df0c145..877bbe7bc 100755 --- a/scripts/local_test_pyg.sh +++ b/scripts/local_test_pyg.sh @@ -28,10 +28,10 @@ mkdir -p experiments tfds build 
mlpf/heptfds/cms_pf/ttbar --manual_dir ./local_test_data #test gravnet -python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type gravnet --pipeline +python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type gravnet --pipeline --dtype float32 #test transformer -python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type attention --pipeline +python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type attention --pipeline --dtype float32 --attention-type math #test GNN-LSH with export -python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type gnn_lsh --export-onnx --pipeline +python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type gnn_lsh --export-onnx --pipeline --dtype float32 diff --git a/scripts/tallinn/a100/pytorch.sh b/scripts/tallinn/a100/pytorch.sh index 43565263f..215508c40 100755 --- a/scripts/tallinn/a100/pytorch.sh +++ b/scripts/tallinn/a100/pytorch.sh @@ -1,16 +1,23 @@ #!/bin/bash #SBATCH --partition gpu #SBATCH --gres gpu:a100:1 -#SBATCH --mem-per-gpu 40G +#SBATCH --mem-per-gpu 80G #SBATCH -o logs/slurm-%x-%j-%N.out -IMG=/home/software/singularity/pytorch.simg:2023-12-06 +IMG=/home/software/singularity/pytorch.simg:2024-02-05 cd ~/particleflow -#TF training +#pytorch training singularity exec -B /scratch/persistent --nv \ --env PYTHONPATH=hep_tfds \ $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 1 \ --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pytorch/pyg-cms.yaml \ - --train --conv-type gnn_lsh --num-epochs 20 --gpu-batch-multiplier 4 --num-workers 1 --prefetch-factor 10 --ntrain 10000 --nvalid 10000 + --train --conv-type attention --num-epochs 10 --gpu-batch-multiplier 40 --num-workers 2 --prefetch-factor 20 +# --train --conv-type gnn_lsh --num-epochs 20 --gpu-batch-multiplier 4 --num-workers 1 --prefetch-factor 10 --ntrain 10000 --nvalid 10000 # --train --conv-type mamba --num-epochs 20 --gpu-batch-multiplier 10 --num-workers 1 --prefetch-factor 10 --ntrain 10000 --nvalid 10000 + +# singularity exec -B /scratch/persistent --nv \ +# --env PYTHONPATH=hep_tfds \ +# $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 1 \ +# --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pytorch/pyg-cms.yaml \ +# --test --make-plots --conv-type attention --gpu-batch-multiplier 10 --num-workers 1 --prefetch-factor 10 --load experiments/pyg-cms_20240204_183048_293390/sub1/best_weights.pth --ntest 1000 diff --git a/scripts/tallinn/rtx/pytorch.sh b/scripts/tallinn/rtx/pytorch.sh index a8f50d7a7..08f359894 100755 --- a/scripts/tallinn/rtx/pytorch.sh +++ b/scripts/tallinn/rtx/pytorch.sh @@ -35,4 +35,4 @@ 
IMG=/home/software/singularity/pytorch.simg:2023-12-06 # --env PYTHONPATH=hep_tfds \ # $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 1 \ # --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pytorch/pyg-cms.yaml \ -# --test --make-plots --conv-type mamba --gpu-batch-multiplier 5 --num-workers 1 --prefetch-factor 10 --load experiments/pyg-cms_20240126_221457_189384/sub1/best_weights.pth --ntest 1000 +# --test --make-plots --conv-type attention --gpu-batch-multiplier 10 --num-workers 1 --prefetch-factor 10 --load experiments/pyg-cms_20240204_183048_293390/sub1/best_weights.pth --ntest 1000
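
Note on the attention-backend selection introduced above (an illustrative sketch appended after the patch, not part of it): torch 2.2's torch.backends.cuda.sdp_kernel context manager restricts scaled-dot-product attention to a single backend ("math", "efficient", or "flash"), and the flash kernel additionally wants half-precision inputs, which is why the training loop now wraps the forward pass in torch.autocast with a configurable dtype and the Collater pads sequences toward a power of two. The ATTN_PARAMS dict and next_power_of_2 helper below mirror the patch; the attention_block function, module sizes, and tensor shapes are hypothetical and chosen only for the example.

    import torch
    from torch.backends.cuda import sdp_kernel

    # backend flag combinations, mirroring SelfAttentionLayer.attn_params in mlpf/pyg/mlpf.py
    ATTN_PARAMS = {
        "math": {"enable_math": True, "enable_mem_efficient": False, "enable_flash": False},
        "efficient": {"enable_math": False, "enable_mem_efficient": True, "enable_flash": False},
        "flash": {"enable_math": False, "enable_mem_efficient": False, "enable_flash": True},
    }

    def next_power_of_2(x):
        # same helper as in mlpf/pyg/PFDataset.py: pad sequence lengths up to a power of two
        return 1 if x == 0 else 2 ** (x - 1).bit_length()

    def attention_block(x, mha, attention_type="flash", dtype=torch.bfloat16):
        # x: (batch, padded_seq_len, embedding_dim); shapes are hypothetical
        device_type = "cuda" if x.is_cuda else "cpu"
        # autocast supplies the fp16/bf16 inputs the flash kernel requires; disabled on CPU
        with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"):
            # restrict scaled-dot-product attention to the requested backend
            with sdp_kernel(**ATTN_PARAMS[attention_type]):
                out, _ = mha(x, x, x, need_weights=False)
        return out

    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        mha = torch.nn.MultiheadAttention(256, 16, batch_first=True).to(device)
        seq_len = next_power_of_2(1000)  # -> 1024
        x = torch.randn(2, seq_len, 256, device=device)
        attn_type = "flash" if device == "cuda" else "math"  # flash needs a CUDA device
        print(attention_block(x, mha, attn_type).shape)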