diff --git a/mega_vit/main.py b/mega_vit/main.py index f81092f..3af7342 100644 --- a/mega_vit/main.py +++ b/mega_vit/main.py @@ -39,7 +39,13 @@ def forward(self, x): return self.net(x) class Attention(nn.Module): - def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + def __init__( + self, + dim, + heads = 8, + dim_head = 64, + dropout = 0. + ): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) @@ -48,6 +54,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): self.scale = dim_head ** -0.5 self.norm = nn.LayerNorm(dim) + self.norm_k = nn.LayerNorm(dim) + self.norm_v = nn.LayerNorm(dim) self.attend = nn.Softmax(dim = -1) self.dropout = nn.Dropout(dropout) @@ -65,6 +73,10 @@ def forward(self, x): qkv = self.to_qkv(x).chunk(3, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + # #normalize key and values, QK Normalization + # k = self.norm_k(k) + # v = self.norm_v(v) + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = self.attend(dots) @@ -87,9 +99,12 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): def forward(self, x): for attn, ff in self.layers: - x = attn(x) + x - x = ff(x) + x - + #layernorm before attention + x = self.norm(x) + + #parallel + x = x + attn(x) + ff(x) + return self.norm(x) class ViT(nn.Module): @@ -155,21 +170,41 @@ def forward(self, img): +# import torch + +# v = ViT( +# image_size = 224, +# patch_size = 14, +# num_classes = 1000, +# dim = 6144, +# depth = 48, +# heads = 48, +# mlp_dim = 2048, +# dropout = 0.1, +# emb_dropout = 0.1 +# ) + +# img = torch.randn(1, 3, 224, 224) + +# preds = v(img) # (1, 1000) +# print(preds) + import torch +# from vit_pytorch import ViT v = ViT( - image_size = 224, - patch_size = 14, + image_size = 256, + patch_size = 32, num_classes = 1000, - dim = 6144, - depth = 48, - heads = 48, + dim = 1024, + depth = 6, + heads = 16, mlp_dim = 2048, dropout = 0.1, emb_dropout = 0.1 ) -img = torch.randn(1, 3, 224, 224) +img = torch.randn(1, 3, 256, 256) preds = v(img) # (1, 1000) print(preds) \ No newline at end of file