Commit: norm qk values

Kye committed Oct 2, 2023
1 parent 612fe42 commit f3f7ab2
Showing 1 changed file with 9 additions and 5 deletions.
mega_vit/main.py: 14 changes (9 additions, 5 deletions)

@@ -59,8 +59,8 @@ def __init__(
         self.scale = dim_head ** -0.5

         self.norm = nn.LayerNorm(dim)
-        self.norm_k = nn.LayerNorm(dim)
-        self.norm_v = nn.LayerNorm(dim)
+        self.norm_k = nn.LayerNorm(dim_head)
+        self.norm_v = nn.LayerNorm(dim_head)

         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)

@@ -79,8 +79,8 @@ def forward(self, x):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

         # #normalize key and values, QK Normalization
-        # k = self.norm_k(k)
-        # v = self.norm_v(v)
+        k = self.norm_k(k)
+        v = self.norm_v(v)

         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
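
For context, a minimal sketch of the attention block with this change applied: after the head split the last axis of k and v has size dim_head, so the previously commented-out QK normalization needs LayerNorm(dim_head) rather than LayerNorm(dim). The pieces not shown in the diff (inner_dim, to_qkv, to_out) are assumptions based on the usual ViT attention layout and may differ from the actual file.

    import torch
    from torch import nn
    from einops import rearrange

    class Attention(nn.Module):
        def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
            super().__init__()
            inner_dim = dim_head * heads          # assumed helper, not shown in the diff
            self.heads = heads
            self.scale = dim_head ** -0.5

            self.norm = nn.LayerNorm(dim)
            # QK Normalization: normalize over the per-head axis (size dim_head)
            self.norm_k = nn.LayerNorm(dim_head)
            self.norm_v = nn.LayerNorm(dim_head)

            self.attend = nn.Softmax(dim = -1)
            self.dropout = nn.Dropout(dropout)
            self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)   # assumed
            self.to_out = nn.Linear(inner_dim, dim)                     # assumed

        def forward(self, x):
            x = self.norm(x)
            qkv = self.to_qkv(x).chunk(3, dim = -1)
            # split heads: (b, n, h*d) -> (b, h, n, d); the last axis has size dim_head
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

            # normalize key and values, as enabled by this commit
            k = self.norm_k(k)
            v = self.norm_v(v)

            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
            attn = self.dropout(self.attend(dots))
            out = torch.matmul(attn, v)
            out = rearrange(out, 'b h n d -> b n (h d)')
            return self.to_out(out)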

@@ -148,7 +148,11 @@ def __init__(
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+            Rearrange(
+                'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
+                p1 = patch_height,
+                p2 = patch_width
+            ),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),
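
The last hunk only reflows the Rearrange call across multiple lines; behavior is unchanged. As a quick illustration of what this patch embedding computes, here is a small, self-contained sketch (the image size and dimensions below are arbitrary example values, not taken from this repository):

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    # example sizes only: a 224x224 RGB image cut into 16x16 patches
    patch_height, patch_width = 16, 16
    channels, dim = 3, 1024
    patch_dim = channels * patch_height * patch_width   # 3 * 16 * 16 = 768

    to_patch_embedding = nn.Sequential(
        Rearrange(
            'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',   # image -> sequence of flattened patches
            p1 = patch_height,
            p2 = patch_width
        ),
        nn.LayerNorm(patch_dim),
        nn.Linear(patch_dim, dim),
        nn.LayerNorm(dim),
    )

    img = torch.randn(2, 3, 224, 224)
    print(to_patch_embedding(img).shape)   # torch.Size([2, 196, 1024]); 196 = (224 // 16) ** 2 patches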
