enable FlashAttention in pytorch, update to torch 2.2.0 (jpata#292)

* implement attention configuration

jpata authored Feb 13, 2024
1 parent 703c57c commit af3225c
Showing 13 changed files with 176 additions and 58 deletions.
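At its core, the commit routes the existing nn.MultiheadAttention layers through PyTorch's scaled-dot-product-attention (SDPA) backends, selectable as "math", "efficient", or "flash". The following is a minimal sketch of the underlying torch 2.2 API, not code from this commit; it assumes a CUDA GPU and fp16 inputs, which the FlashAttention kernel requires:

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

# hypothetical tensors with shape (batch, heads, sequence length, head dim)
q = k = v = torch.randn(4, 8, 256, 64, device="cuda", dtype=torch.float16)

# force the FlashAttention kernel; torch raises a RuntimeError if no enabled backend applies
with sdp_kernel(enable_math=False, enable_mem_efficient=False, enable_flash=True):
    out = F.scaled_dot_product_attention(q, k, v)

With need_weights=False, nn.MultiheadAttention should take the same SDPA fast path, which is why wrapping the call in this context manager (as done in mlpf.py below) selects the kernel.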
12 changes: 6 additions & 6 deletions .github/workflows/test.yml
@@ -28,8 +28,8 @@ jobs:
python-version: "3.10.12"
cache: "pip"
- run: pip install -r requirements.txt
- - run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- - run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
+ - run: pip3 install torch==2.2.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ - run: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv torch_geometric -f https://data.pyg.org/whl/torch-2.2.0+cpu.html

tf-unittests:
runs-on: ubuntu-22.04
@@ -101,8 +101,8 @@ jobs:
python-version: "3.10.12"
cache: "pip"
- run: pip install -r requirements.txt
- - run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- - run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
+ - run: pip3 install torch==2.2.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ - run: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv torch_geometric -f https://data.pyg.org/whl/torch-2.2.0+cpu.html
- run: PYTHONPATH=. python3 -m unittest tests/test_torch_and_tf.py

pyg-pipeline:
@@ -115,6 +115,6 @@ jobs:
python-version: "3.10.12"
cache: "pip"
- run: pip install -r requirements.txt
- - run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- - run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
+ - run: pip3 install torch==2.2.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ - run: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv torch_geometric -f https://data.pyg.org/whl/torch-2.2.0+cpu.html
- run: ./scripts/local_test_pyg.sh
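A quick, hedged sanity check of the upgraded environment is to confirm the installed torch version and query the SDPA backend flags, which torch 2.2 exposes under torch.backends.cuda:

import torch

print(torch.__version__)                                # expected to start with "2.2"
print(torch.backends.cuda.flash_sdp_enabled())          # FlashAttention backend flag
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # memory-efficient backend flag
print(torch.backends.cuda.math_sdp_enabled())           # math (reference) backend flag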
23 changes: 18 additions & 5 deletions mlpf/pyg/PFDataset.py
@@ -103,15 +103,19 @@ def __init__(
)


+ def next_power_of_2(x):
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()


class Collater:
"""Based on the Collater found on torch_geometric docs we build our own."""

- def __init__(self, keys_to_get, follow_batch=None, exclude_keys=None, pad_bin_size=640, pad_3d=True):
+ def __init__(self, keys_to_get, follow_batch=None, exclude_keys=None, pad_3d=True, pad_power_of_two=True):
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
self.keys_to_get = keys_to_get
- self.pad_bin_size = pad_bin_size
self.pad_3d = pad_3d
+ self.pad_power_of_two = False

def __call__(self, inputs):
num_samples_in_batch = len(inputs)
@@ -129,7 +133,16 @@ def __call__(self, inputs):
if not self.pad_3d:
return ret
else:
- ret = {k: torch_geometric.utils.to_dense_batch(getattr(ret, k), ret.batch) for k in elem_keys}
+ # pad to closest power of two
+ if self.pad_power_of_two:
+ sizes = [next_power_of_2(len(b.X)) for b in batch]
+ max_size = max(sizes)
+ else:
+ max_size = None
+ ret = {
+ k: torch_geometric.utils.to_dense_batch(getattr(ret, k), ret.batch, max_num_nodes=max_size)
+ for k in elem_keys
+ }

ret["mask"] = ret["X"][1]

@@ -185,7 +198,7 @@ def __len__(self):
return len_


- def get_interleaved_dataloaders(world_size, rank, config, use_cuda, pad_3d, use_ray):
+ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, pad_3d, pad_power_of_two, use_ray):
loaders = {}
for split in ["train", "valid"]: # build train, valid dataset and dataloaders
loaders[split] = []
@@ -219,7 +232,7 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, pad_3d, use_
loader = PFDataLoader(
dataset,
batch_size=batch_size,
- collate_fn=Collater(["X", "ygen"], pad_3d=pad_3d),
+ collate_fn=Collater(["X", "ygen"], pad_3d=pad_3d, pad_power_of_two=pad_power_of_two),
sampler=sampler,
num_workers=config["num_workers"],
prefetch_factor=config["prefetch_factor"],
2 changes: 1 addition & 1 deletion mlpf/pyg/inference.py
@@ -7,7 +7,6 @@
import mplhep
import numpy as np
import torch
- import torch_geometric
import tqdm
import vector
from jet_utils import build_dummy_array, match_two_jet_collections
@@ -22,6 +21,7 @@
plot_particles,
plot_sum_energy,
)
+ import torch_geometric
from torch_geometric.data import Batch

from .logger import _logger
28 changes: 24 additions & 4 deletions mlpf/pyg/mlpf.py
@@ -4,6 +4,9 @@

from .gnn_lsh import CombinedGraphLayer

+ from torch.backends.cuda import sdp_kernel
+ from pyg.logger import _logger


class GravNetLayer(nn.Module):
def __init__(self, embedding_dim, space_dimensions, propagate_dimensions, k, dropout):
@@ -22,7 +25,7 @@ def forward(self, x, batch_index):


class SelfAttentionLayer(nn.Module):
- def __init__(self, embedding_dim=128, num_heads=2, width=128, dropout=0.1):
+ def __init__(self, embedding_dim=128, num_heads=2, width=128, dropout=0.1, attention_type="efficient"):
super(SelfAttentionLayer, self).__init__()
self.act = nn.ELU
self.mha = torch.nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
@@ -32,9 +35,20 @@ def __init__(self, embedding_dim=128, num_heads=2, width=128, dropout=0.1):
nn.Linear(embedding_dim, width), self.act(), nn.Linear(width, embedding_dim), self.act()
)
self.dropout = torch.nn.Dropout(dropout)
+ self.attention_type = attention_type
+ _logger.info("using attention_type={}".format(attention_type))
+ self.attn_params = {
+ "math": {"enable_math": True, "enable_mem_efficient": False, "enable_flash": False},
+ "efficient": {"enable_math": False, "enable_mem_efficient": True, "enable_flash": False},
+ "flash": {"enable_math": False, "enable_mem_efficient": False, "enable_flash": True},
+ }

def forward(self, x, mask):
- x = self.norm0(x + self.mha(x, x, x, key_padding_mask=mask, need_weights=False)[0])
+ # explicitly call the desired attention mechanism
+ with sdp_kernel(**self.attn_params[self.attention_type]):
+ mha_out = self.mha(x, x, x, need_weights=False)[0]

+ x = self.norm0(x + mha_out)
x = self.norm1(x + self.seq(x))
x = self.dropout(x)
x = x * (~mask.unsqueeze(-1))
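The rewritten forward wraps the unchanged nn.MultiheadAttention call in the sdp_kernel context selected by attention_type; it also no longer passes key_padding_mask to the attention call, and padded positions are instead zeroed afterwards by the x * (~mask.unsqueeze(-1)) line. A standalone sketch of the same pattern, with assumed shapes and outside the repository's classes:

import torch
from torch.backends.cuda import sdp_kernel

mha = torch.nn.MultiheadAttention(embed_dim=128, num_heads=2, batch_first=True)
x = torch.randn(2, 256, 128)  # (batch, padded sequence, embedding)

# mirror attn_params["efficient"] above; these flags only affect CUDA execution
with sdp_kernel(enable_math=False, enable_mem_efficient=True, enable_flash=False):
    out, _ = mha(x, x, x, need_weights=False)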
@@ -117,6 +131,7 @@ def __init__(
propagate_dimensions=32,
space_dimensions=4,
conv_type="gravnet",
+ attention_type="flash",
# gnn-lsh specific parameters
bin_size=640,
max_num_bins=200,
@@ -168,8 +183,12 @@ def __init__(
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
for i in range(num_convs):
- self.conv_id.append(SelfAttentionLayer(embedding_dim, num_heads, width, dropout))
- self.conv_reg.append(SelfAttentionLayer(embedding_dim, num_heads, width, dropout))
+ self.conv_id.append(
+ SelfAttentionLayer(embedding_dim, num_heads, width, dropout, attention_type=attention_type)
+ )
+ self.conv_reg.append(
+ SelfAttentionLayer(embedding_dim, num_heads, width, dropout, attention_type=attention_type)
+ )
elif self.conv_type == "mamba":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
@@ -209,6 +228,7 @@ def __init__(
# elementwise DNN for node charge regression, classes (-1, 0, 1)
self.nn_charge = ffn(decoding_dim + num_classes, 3, width, self.act, dropout)

+ # @torch.compile
def forward(self, X_features, batch_or_mask):
embeddings_id, embeddings_reg = [], []
if self.num_convs != 0:
(diffs for the remaining 9 changed files not shown)
