diff --git a/mega_vit/main.py b/mega_vit/main.py
index 53f84ca..f81092f 100644
--- a/mega_vit/main.py
+++ b/mega_vit/main.py
@@ -1,163 +1,14 @@
-from collections import namedtuple
-from functools import wraps
-
 import torch
-import torch.nn.functional as F
+from torch import nn
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
-from packaging import version
-from torch import einsum, nn
 
 # helpers
 
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
-# constants
-
-Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
-
-# helpers
-
-def exists(val):
-    return val is not None
-
-def once(fn):
-    called = False
-    @wraps(fn)
-    def inner(x):
-        nonlocal called
-        if called:
-            return
-        called = True
-        return fn(x)
-    return inner
-
-print_once = once(print)
-
-# main class
-
-class Attention(nn.Module):
-    def __init__(
-        self,
-        dropout = 0.,
-        causal = False,
-        use_flash_attn = False
-    ):
-        super().__init__()
-        self.dropout = dropout
-        self.attn_dropout = nn.Dropout(dropout)
-
-        self.causal = causal
-        self.register_buffer("mask", None, persistent=False)
-
-        self.use_flash_attn = use_flash_attn
-        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
-
-        # determine efficient attention configs for cuda and cpu
-
-        self.cpu_config = Config(True, True, True)
-        self.cuda_config = None
-
-        if not torch.cuda.is_available() or not use_flash_attn:
-            return
-
-        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
-
-        if device_properties.major == 8 and device_properties.minor == 0:
-            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
-            self.cuda_config = Config(True, False, False)
-        else:
-            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
-            self.cuda_config = Config(False, True, True)
-
-    def get_mask(self, n, device):
-        if exists(self.mask) and self.mask.shape[-1] >= n:
-            return self.mask[:n, :n]
-
-        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
-        self.register_buffer("mask", mask, persistent=False)
-        return mask
-
-    def flash_attn(self, q, k, v, mask = None):
-        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
-
-        # Recommended for multi-query single-key-value attention by Tri Dao
-        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
-
-        k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
-        v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
-
-        k = LayerNorm(k.shape[-1])(k)
-        v = LayerNorm(v.shape[-2])(v)
-
-
-        # Check if mask exists and expand to compatible shape
-        # The mask is B L, so it would have to be expanded to B H N L
-
-        if exists(mask):
-            mask = rearrange(mask, 'b j -> b 1 1 j')
-            mask = mask.expand(-1, heads, q_len, -1)
-
-        # Check if there is a compatible device for flash attention
-
-        config = self.cuda_config if is_cuda else self.cpu_config
-
-        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
-
-        with torch.backends.cuda.sdp_kernel(**config._asdict()):
-            out = F.scaled_dot_product_attention(
-                q, k, v,
-                attn_mask = mask,
-                dropout_p = self.dropout if self.training else 0.,
-                is_causal = self.causal
-            )
-
-        return out
-
-    def forward(self, q, k, v, mask = None):
-        """
-        einstein notation
-        b - batch
-        h - heads
-        n, i, j - sequence length (base sequence length, source, target)
-        d - feature dimension
-        """
-
-        n, device = q.shape[-2], q.device
-
-        scale = q.shape[-1] ** -0.5
-
-        if self.use_flash_attn:
-            return self.flash_attn(q, k, v, mask = mask)
-
-        # similarity
-
-        sim = einsum("b h i d, b j d -> b h i j", q, k) * scale
-
-        # key padding mask
-
-        if exists(mask):
-            mask = rearrange(mask, 'b j -> b 1 1 j')
-            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
-
-        # causal mask
-
-        if self.causal:
-            causal_mask = self.get_mask(n, device)
-            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
-
-        # attention
-
-        attn = sim.softmax(dim=-1)
-        attn = self.attn_dropout(attn)
-
-        # aggregate values
-
-        out = einsum("b h i j, b j d -> b h i d", attn, v)
-
-        return out
-
+# classes
 
 class LayerNorm(nn.Module):
     def __init__(
@@ -172,7 +23,6 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.layer_norm(x)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
@@ -188,7 +38,41 @@ def __init__(self, dim, hidden_dim, dropout = 0.):
 
     def forward(self, x):
         return self.net(x)
 
+class Attention(nn.Module):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        project_out = not (heads == 1 and dim_head == dim)
+
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+
+        self.norm = nn.LayerNorm(dim)
+
+        self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        ) if project_out else nn.Identity()
+
+    def forward(self, x):
+        x = self.norm(x)
+
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+        attn = self.attend(dots)
+        attn = self.dropout(attn)
+
+        out = torch.matmul(attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
 
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
@@ -270,20 +154,22 @@ def forward(self, img):
 
         return self.mlp_head(x)
 
+
 import torch
 
 v = ViT(
-    image_size = 256,
-    patch_size = 32,
+    image_size = 224,
+    patch_size = 14,
     num_classes = 1000,
-    dim = 1024,
-    depth = 6,
-    heads = 16,
+    dim = 6144,
+    depth = 48,
+    heads = 48,
     mlp_dim = 2048,
     dropout = 0.1,
     emb_dropout = 0.1
 )
 
-img = torch.randn(1, 3, 256, 256)
+img = torch.randn(1, 3, 224, 224)
 
-preds = v(img) # (1, 1000)
\ No newline at end of file
+preds = v(img) # (1, 1000)
+print(preds)
\ No newline at end of file
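
Below is a minimal standalone sketch (not part of the patch) of how the Attention module added in this diff could be exercised on its own. The import path and the small tensor sizes are illustrative assumptions, not something the patch specifies.

import torch
from mega_vit.main import Attention  # assumed import path for the patched module

# hypothetical small config; the example at the end of the patch uses dim = 6144, heads = 48
attn = Attention(dim = 512, heads = 8, dim_head = 64, dropout = 0.)

x = torch.randn(1, 65, 512)   # (batch, tokens, dim), e.g. 64 patches + 1 cls token
out = attn(x)                 # self-attention preserves the (batch, tokens, dim) shape
print(out.shape)              # torch.Size([1, 65, 512])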