Skip to content

Commit

Permalink
code quality
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Oct 2, 2023
1 parent 4ef4f19 commit 612fe42
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions mega_vit/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn

# helpers

Expand All @@ -24,7 +24,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer_norm(x)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
def __init__(
self,
dim,
hidden_dim,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
Expand Down Expand Up @@ -87,7 +92,15 @@ def forward(self, x):
return self.to_out(out)

class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.
):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
Expand Down

0 comments on commit 612fe42

Please sign in to comment.