
Commit

clean up
Kye committed Oct 2, 2023
1 parent eba4abf commit 42269a3
Showing 1 changed file with 45 additions and 159 deletions.
204 changes: 45 additions & 159 deletions mega_vit/main.py
@@ -1,163 +1,14 @@
from collections import namedtuple
from functools import wraps

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from packaging import version
from torch import einsum, nn

# helpers

def pair(t):
return t if isinstance(t, tuple) else (t, t)

# constants

Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
return val is not None

def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner

print_once = once(print)
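
# Illustrative use of the helper above (not part of the original file): wrapping
# `print` with `once` lets the device-detection messages further down print a
# single time per process, e.g.
#   print_once('A100 GPU detected')   # prints
#   print_once('A100 GPU detected')   # no-op, the wrapper has already fired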

# main class

class Attention(nn.Module):
def __init__(
self,
dropout = 0.,
causal = False,
use_flash_attn = False
):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)

self.causal = causal
self.register_buffer("mask", None, persistent=False)

self.use_flash_attn = use_flash_attn
assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

# determine efficient attention configs for cuda and cpu

self.cpu_config = Config(True, True, True)
self.cuda_config = None

if not torch.cuda.is_available() or not use_flash_attn:
return

device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = Config(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = Config(False, True, True)

def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]

mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask

def flash_attn(self, q, k, v, mask = None):
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

k = LayerNorm(k.shape[-1])(k)
v = LayerNorm(v.shape[-1])(v) # normalize v over its feature dimension as well, not the sequence dimension

# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B H N L

if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
mask = mask.expand(-1, heads, q_len, -1)

# Check if there is a compatible device for flash attention

config = self.cuda_config if is_cuda else self.cpu_config

# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout if self.training else 0.,
is_causal = self.causal
)

return out

def forward(self, q, k, v, mask = None):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""

n, device = q.shape[-2], q.device

scale = q.shape[-1] ** -0.5

if self.use_flash_attn:
return self.flash_attn(q, k, v, mask = mask)

# similarity

sim = einsum("b h i d, b j d -> b h i j", q, k) * scale

# key padding mask

if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

# causal mask

if self.causal:
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

# attention

attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)

# aggregate values

out = einsum("b h i j, b j d -> b h i d", attn, v)

return out
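
# For reference, a minimal sketch of how the module deleted by this commit was driven
# (shapes are illustrative assumptions): queries arrive split into heads, while a single
# key/value head is shared across them, matching the einsums above.
#   attn = Attention(use_flash_attn = False)
#   q = torch.randn(1, 8, 512, 64)   # (batch, heads, seq, dim_head)
#   k = torch.randn(1, 512, 64)      # (batch, seq, dim_head)
#   v = torch.randn(1, 512, 64)
#   out = attn(q, k, v)              # -> (1, 8, 512, 64)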

# classes

class LayerNorm(nn.Module):
def __init__(
@@ -172,7 +23,6 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer_norm(x)


class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
@@ -188,7 +38,41 @@ def __init__(self, dim, hidden_dim, dropout = 0.):
def forward(self, x):
return self.net(x)
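
# Illustrative call (numbers are assumptions): the block is shape-preserving,
# projecting dim -> hidden_dim -> dim.
#   ff = FeedForward(dim = 1024, hidden_dim = 2048)
#   ff(torch.randn(1, 65, 1024)).shape   # -> torch.Size([1, 65, 1024])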

class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)

self.heads = heads
self.scale = dim_head ** -0.5

self.norm = nn.LayerNorm(dim)

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()

def forward(self, x):
x = self.norm(x)

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
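
# Illustrative use of the pre-norm attention block above (numbers are assumptions):
#   attn = Attention(dim = 1024, heads = 16, dim_head = 64)
#   x = torch.randn(1, 65, 1024)         # (batch, tokens, dim), e.g. 64 patches + a CLS token
#   attn(x).shape                        # -> torch.Size([1, 65, 1024])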

class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
@@ -270,20 +154,22 @@ def forward(self, img):
return self.mlp_head(x)




v = ViT(
image_size = 256,
patch_size = 32,
image_size = 224,
patch_size = 14,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
dim = 6144,
depth = 48,
heads = 48,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)
preds = v(img) # (1, 1000)
print(preds)
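
A quick way to sanity-check the size of the configuration above (a sketch, not part of the commit; it assumes the ViT constructor leaves dim_head at its default of 64):

n_params = sum(p.numel() for p in v.parameters())
print(f"{n_params / 1e9:.2f}B parameters")  # dim = 6144, depth = 48, heads = 48 lands in the billions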
