From 8f9e00531a76be6886895f97451b31961ba97bc8 Mon Sep 17 00:00:00 2001
From: Kye
Date: Tue, 3 Oct 2023 01:57:38 -0400
Subject: [PATCH] new version

---
 mega_vit/main.py | 27 +++++++++++++++++++--------
 pyproject.toml   |  2 +-
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/mega_vit/main.py b/mega_vit/main.py
index 15e1273..bb3ac2a 100644
--- a/mega_vit/main.py
+++ b/mega_vit/main.py
@@ -2,7 +2,7 @@
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import nn
-
+import torch.nn.functional as F
 # helpers

 def pair(t):
@@ -78,18 +78,29 @@ def forward(self, x):
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

-        # #normalize key and values, QK Normalization
+        # normalize keys and values, also known as QK Normalization
         k = self.norm_k(k)
         v = self.norm_v(v)

-        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+        # previous implementation (should this be replaced?)
+        # dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+        # attn = self.attend(dots)
+        # attn = self.dropout(attn)
+        # out = torch.matmul(attn, v)
+
+        # new attention path
+        with torch.backends.cuda.sdp_kernel(enable_math=True):
+            # scaled dot-product attention
+            out = F.scaled_dot_product_attention(q, k, v)
+
+            # dropout
+            out = self.dropout(out)

-        attn = self.attend(dots)
-        attn = self.dropout(attn)
+            # rearrange to original shape
+            out = rearrange(out, 'b h n d -> b n (h d)')

-        out = torch.matmul(attn, v)
-        out = rearrange(out, 'b h n d -> b n (h d)')
-        return self.to_out(out)
+            # project out
+            return self.to_out(out)

 class Transformer(nn.Module):
     def __init__(
diff --git a/pyproject.toml b/pyproject.toml
index aa0243b..60cfc3e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "mega-vit"
-version = "0.0.3"
+version = "0.0.4"
 description = "mega-vit - Pytorch"
 license = "MIT"
 authors = ["Kye Gomez "]
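
Note on the main.py hunk: it swaps the hand-written attention (softmax of q k^T times self.scale, then a matmul with v) for torch.nn.functional.scaled_dot_product_attention inside the torch.backends.cuda.sdp_kernel context manager. The sketch below is a minimal, self-contained check, not code from the repository; the shapes b, h, n, d are illustrative assumptions, and dropout is omitted so the comparison is deterministic. It shows that the fused call matches the manual computation the patch comments out, since scaled_dot_product_attention applies the 1/sqrt(d) scaling internally.

# minimal sketch, not from the repository: compare the fused SDPA path
# added in this patch against the manual attention it replaces
# b, h, n, d are illustrative assumptions (batch, heads, tokens, head dim)
import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 16, 64
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

# manual path, as in the removed lines (dropout omitted for a deterministic check)
scale = d ** -0.5
dots = torch.matmul(q, k.transpose(-1, -2)) * scale
attn = dots.softmax(dim=-1)
manual_out = torch.matmul(attn, v)

# fused path, as in the added lines; scaled_dot_product_attention applies
# the 1/sqrt(d) scale internally, so self.scale is no longer needed
with torch.backends.cuda.sdp_kernel(enable_math=True):
    fused_out = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual_out, fused_out, atol=1e-5))  # expected: True

One behavioral difference worth noting: the patch applies self.dropout to the attention output rather than to the attention weights as before; if the original behavior is wanted, F.scaled_dot_product_attention accepts a dropout_p argument.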