parallel, and layernorm before attention
Kye committed Oct 2, 2023
1 parent 42269a3 commit e31d66e
Showing 1 changed file with 45 additions and 10 deletions.
55 changes: 45 additions & 10 deletions mega_vit/main.py
@@ -39,7 +39,13 @@ def forward(self, x):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(
+        self,
+        dim,
+        heads = 8,
+        dim_head = 64,
+        dropout = 0.
+    ):
         super().__init__()
         inner_dim = dim_head * heads
         project_out = not (heads == 1 and dim_head == dim)
@@ -48,6 +54,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.scale = dim_head ** -0.5
 
         self.norm = nn.LayerNorm(dim)
+        self.norm_k = nn.LayerNorm(dim)
+        self.norm_v = nn.LayerNorm(dim)
 
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
@@ -65,6 +73,10 @@ def forward(self, x):
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
+        # normalize keys and values, QK Normalization (currently disabled)
+        # k = self.norm_k(k)
+        # v = self.norm_v(v)
+
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
         attn = self.attend(dots)
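
For context, QK Normalization as usually described applies LayerNorm to the queries and keys before the dot product, rather than to the keys and values as in the commented-out lines above; note also that after the rearrange the per-head tensors end in dim_head, while norm_k and norm_v are sized to dim, so the norms would need to be sized to dim_head (or applied before the rearrange) if these lines are re-enabled. Below is a minimal, hedged sketch of an attention forward pass with QK Normalization actually enabled; the class name QKNormAttention and the norm_q / norm_k modules are illustrative choices, not part of this commit.

import torch
from torch import nn
from einops import rearrange

class QKNormAttention(nn.Module):
    # Sketch only: multi-head self-attention with QK Normalization enabled.
    # norm_q / norm_k are hypothetical per-head LayerNorms over dim_head.
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.norm_q = nn.LayerNorm(dim_head)
        self.norm_k = nn.LayerNorm(dim_head)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # QK Normalization: LayerNorm the queries and keys so the attention
        # logits stay bounded regardless of model width
        q, k = self.norm_q(q), self.norm_k(k)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.dropout(self.attend(dots))

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)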
@@ -87,9 +99,12 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
 
     def forward(self, x):
         for attn, ff in self.layers:
-            x = attn(x) + x
-            x = ff(x) + x
+            #layernorm before attention
+            x = self.norm(x)
 
+            #parallel
+            x = x + attn(x) + ff(x)
+
         return self.norm(x)
 
 class ViT(nn.Module):
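
The change above does two things: it applies LayerNorm before attention (pre-norm) and it replaces the sequential residual updates (x = attn(x) + x followed by x = ff(x) + x) with a parallel formulation in which the attention and feed-forward branches read the same input and are added to the residual stream together. A minimal standalone sketch of a parallel pre-LayerNorm block follows; the ParallelBlock name and constructor signature are illustrative, not the repository's API.

import torch
from torch import nn

class ParallelBlock(nn.Module):
    # Sketch of one parallel, pre-LayerNorm transformer block:
    # both branches read the same normalized input and their outputs
    # are added back to the residual stream in a single step.
    def __init__(self, dim, attn: nn.Module, ff: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = attn  # any module mapping (b, n, dim) -> (b, n, dim)
        self.ff = ff      # any module mapping (b, n, dim) -> (b, n, dim)

    def forward(self, x):
        normed = self.norm(x)                            # layernorm before attention
        return x + self.attn(normed) + self.ff(normed)   # parallel residual update

Note that in the diff, x is overwritten with self.norm(x) before the update, so the residual path itself is normalized at every layer; the sketch keeps the unnormalized x on the skip connection, which is the more common variant of a parallel block.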
@@ -155,21 +170,41 @@ def forward(self, img):
 
 
+
+# import torch
+
+# v = ViT(
+#     image_size = 224,
+#     patch_size = 14,
+#     num_classes = 1000,
+#     dim = 6144,
+#     depth = 48,
+#     heads = 48,
+#     mlp_dim = 2048,
+#     dropout = 0.1,
+#     emb_dropout = 0.1
+# )
+
+# img = torch.randn(1, 3, 224, 224)
+
+# preds = v(img) # (1, 1000)
+# print(preds)
+
 import torch
 # from vit_pytorch import ViT
 
 v = ViT(
-    image_size = 224,
-    patch_size = 14,
+    image_size = 256,
+    patch_size = 32,
     num_classes = 1000,
-    dim = 6144,
-    depth = 48,
-    heads = 48,
+    dim = 1024,
+    depth = 6,
+    heads = 16,
     mlp_dim = 2048,
     dropout = 0.1,
     emb_dropout = 0.1
 )
 
-img = torch.randn(1, 3, 224, 224)
+img = torch.randn(1, 3, 256, 256)
 
 preds = v(img) # (1, 1000)
 print(preds)
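
The example also swaps the very large configuration (now commented out) for a small one, and the token count follows from the patch arithmetic: 224 / 14 = 16 patches per side, i.e. 256 patch tokens, for the old settings versus 256 / 32 = 8 per side, i.e. 64 patch tokens, for the new ones (plus any class token the model prepends). A quick shape check, assuming the ViT above returns one logit per class:

# patch arithmetic for the two configurations in the diff
# large (commented out): 224 // 14 = 16 patches per side -> 16 * 16 = 256 patch tokens
# small (active)       : 256 // 32 =  8 patches per side ->  8 *  8 =  64 patch tokens

img = torch.randn(1, 3, 256, 256)   # batch of one 256x256 RGB image
preds = v(img)                      # v is the ViT instance constructed above
assert preds.shape == (1, 1000)     # (batch, num_classes)
print(preds.shape)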
