diff --git a/mega_vit/main.py b/mega_vit/main.py index bb3ac2a..361df3c 100644 --- a/mega_vit/main.py +++ b/mega_vit/main.py @@ -92,6 +92,9 @@ def forward(self, x): with torch.backends.cuda.sdp_kernel(enable_math=True): #attention out = F.scaled_dot_product_attention(q, k, v) + + #softmax + out = self.attend(out) #dropout out = self.dropout(out)