add an improvised magnitude preserving image transformer
lucidrains committed Feb 7, 2024
1 parent 8ff6835 commit fd5abb9
Showing 2 changed files with 57 additions and 1 deletion.
56 changes: 56 additions & 0 deletions denoising_diffusion_pytorch/karras_unet.py
@@ -624,6 +624,62 @@ def forward(

return self.output_block(x)

# improvised MP Transformer

class MPFeedForward(Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        mp_add_t = 0.3
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            PixelNorm(dim = 1),        # normalize each feature vector across the channel dim
            Conv2d(dim, dim_inner, 1), # magnitude preserving pointwise conv
            MPSiLU(),                  # SiLU rescaled to preserve magnitude
            Conv2d(dim_inner, dim, 1)
        )

        self.mp_add = MPAdd(t = mp_add_t) # magnitude preserving residual sum

    def forward(self, x):
        res = x
        out = self.net(x)
        return self.mp_add(out, res)
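The MP primitives used above (PixelNorm, MPSiLU, MPAdd) are defined earlier in karras_unet.py. For orientation only, here is a minimal, hypothetical sketch of how such primitives are typically written under the magnitude-preserving recipe of Karras et al. (arXiv 2312.02696); the exact constants, names, and signatures in the file may differ.

# hypothetical sketches of the MP primitives, following Karras et al.
# (arXiv 2312.02696) -- not the definitions from karras_unet.py

import math
import torch
import torch.nn.functional as F
from torch import nn

class PixelNormSketch(nn.Module):
    # normalize along `dim`, rescaled by sqrt(n) so each element
    # has roughly unit variance afterwards
    def __init__(self, dim, eps = 1e-4):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        n = x.shape[self.dim]
        return F.normalize(x, dim = self.dim, eps = self.eps) * math.sqrt(n)

class MPSiLUSketch(nn.Module):
    # SiLU divided by its RMS under a unit gaussian input (~0.596),
    # so activations keep unit magnitude in expectation
    def forward(self, x):
        return F.silu(x) / 0.596

class MPAddSketch(nn.Module):
    # magnitude preserving sum: blend b into a with weight t, then
    # divide by the RMS of the blend so the output stays unit magnitude
    def __init__(self, t):
        super().__init__()
        self.t = t

    def forward(self, a, b):
        t = self.t
        return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)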

class MPImageTransformer(Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_mem_kv = 4,
        ff_mult = 4,
        attn_flash = False,
        residual_mp_add_t = 0.3
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
                MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
            ]))

    def forward(self, x):

        # residual connections are handled inside each block via MPAdd

        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)

        return x
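A hypothetical usage sketch (distinct from the collapsed __main__ example below): since MPFeedForward is built from 1x1 convolutions, the transformer consumes 2d feature maps of shape (batch, dim, height, width) and preserves that shape. The import path assumes the class is pulled straight from karras_unet.

# hypothetical usage sketch

import torch
from denoising_diffusion_pytorch.karras_unet import MPImageTransformer

transformer = MPImageTransformer(
    dim = 64,
    depth = 2
)

feats = torch.randn(1, 64, 16, 16)  # (batch, channels, height, width)
out = transformer(feats)            # shape preserved: (1, 64, 16, 16)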

# example

if __name__ == '__main__':
    ...  # body collapsed in the diff view
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.10.3'
+__version__ = '1.10.5'
