diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py
index 7520404c1..f25cba636 100644
--- a/denoising_diffusion_pytorch/karras_unet.py
+++ b/denoising_diffusion_pytorch/karras_unet.py
@@ -624,6 +624,62 @@ def forward(
 
         return self.output_block(x)
 
+# improvised MP Transformer
+
+class MPFeedForward(Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        mult = 4,
+        mp_add_t = 0.3
+    ):
+        super().__init__()
+        dim_inner = int(dim * mult)
+        self.net = nn.Sequential(
+            PixelNorm(dim = 1),
+            Conv2d(dim, dim_inner, 1),
+            MPSiLU(),
+            Conv2d(dim_inner, dim, 1)
+        )
+
+        self.mp_add = MPAdd(t = mp_add_t)
+
+    def forward(self, x):
+        res = x
+        out = self.net(x)
+        return self.mp_add(out, res)
+
+class MPImageTransformer(Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        dim_head = 64,
+        heads = 8,
+        num_mem_kv = 4,
+        ff_mult = 4,
+        attn_flash = False,
+        residual_mp_add_t = 0.3
+    ):
+        super().__init__()
+        self.layers = ModuleList([])
+
+        for _ in range(depth):
+            self.layers.append(ModuleList([
+                Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
+                MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
+            ]))
+
+    def forward(self, x):
+
+        for attn, ff in self.layers:
+            x = attn(x)
+            x = ff(x)
+
+        return x
+
 # example
 
 if __name__ == '__main__':
diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py
index 78a6e51e6..420b4d0ed 100644
--- a/denoising_diffusion_pytorch/version.py
+++ b/denoising_diffusion_pytorch/version.py
@@ -1 1 @@
-__version__ = '1.10.3'
+__version__ = '1.10.5'
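
Note (not part of the patch): the MPAdd residual connections above presumably implement the magnitude-preserving sum from Karras et al., "Analyzing and Improving the Training Dynamics of Diffusion Models" (EDM2), which the rest of karras_unet.py follows. A minimal sketch of that formula, assuming the residual branch is weighted by 1 - t and the block output by t (the exact argument convention lives in MPAdd, defined elsewhere in karras_unet.py; mp_add_sketch is a hypothetical helper for illustration):

    import torch

    def mp_add_sketch(res, out, t = 0.3):
        # blend the two branches, then rescale so that if both inputs have
        # unit variance (and are uncorrelated), the output does as well
        return ((1. - t) * res + t * out) / ((1. - t) ** 2 + t ** 2) ** 0.5

With the default t = 0.3 the block output contributes the smaller share, which keeps activation magnitudes stable as depth grows without any learned normalization.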
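A minimal usage sketch of the new module, assuming this patch is applied and the class is imported from denoising_diffusion_pytorch.karras_unet. Since both the attention and feedforward blocks here are convolutional (1x1 convolutions in MPFeedForward), the transformer operates directly on (batch, channels, height, width) feature maps and preserves their shape; the sizes below are arbitrary:

    import torch
    from denoising_diffusion_pytorch.karras_unet import MPImageTransformer

    transformer = MPImageTransformer(
        dim = 64,      # channel dimension of the incoming feature map
        depth = 2,     # number of (attention, feedforward) pairs
        dim_head = 64,
        heads = 8
    )

    feats = torch.randn(1, 64, 16, 16)   # (batch, channels, height, width)
    out = transformer(feats)             # same shape: (1, 64, 16, 16)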