Skip to content

Commit

Permalink
new version
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Oct 3, 2023
1 parent ffa44a1 commit 8f9e005
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
27 changes: 19 additions & 8 deletions mega_vit/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn

import torch.nn.functional as F
# helpers

def pair(t):
Expand Down Expand Up @@ -78,18 +78,29 @@ def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

# #normalize key and values, QK Normalization
# #normalize key and values, also known as QK Normalization
k = self.norm_k(k)
v = self.norm_v(v)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# should this be replaced?
# dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# attn = self.attend(dots)
# attn = self.dropout(attn)
# out = torch.matmul(attn, v)

# attn
with torch.backends.cuda.sdp_kernel(enable_math=True):
#attention
out = F.scaled_dot_product_attention(q, k, v)

#dropout
out = self.dropout(out)

attn = self.attend(dots)
attn = self.dropout(attn)
#rearrange to original shape
out = rearrange(out, 'b h n d -> b n (h d)')

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
#project out
return self.to_out(out)

class Transformer(nn.Module):
def __init__(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "mega-vit"
version = "0.0.3"
version = "0.0.4"
description = "mega-vit - Pytorch"
license = "MIT"
authors = ["Kye Gomez <[email protected]>"]
Expand Down

0 comments on commit 8f9e005

Please sign in to comment.